-
Notifications
You must be signed in to change notification settings - Fork 2.7k
/
schedule.py
251 lines (229 loc) · 12.3 KB
/
schedule.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
import sys
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Deque, List, Dict, Optional, Set, DefaultDict, Tuple
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, BinaryOps, UnaryOps
from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, GlobalCounters, flatten, getenv, merge_dicts, prod, dedup, all_int
from tinygrad.shape.symbolic import Variable
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.shapetracker import ShapeTracker
# creation can recurse a lot: _recurse_lb, _recursive_lazyop and _is_padding_okay below walk the LazyBuffer graph
# recursively (one Python frame per graph level), so deep graphs need a higher limit than Python's default.
# NOTE: sys.setrecursionlimit is process-wide, so importing this module raises the limit for the whole interpreter.
sys.setrecursionlimit(10000)
# TODO: it's unfortunate this needs to exist, but because of ASSIGN, we have to retain the LazyBuffer structure until post toposort
@dataclass(frozen=True)
class _LBScheduleItem:
  """Prescheduled kernel that still references LazyBuffers; _schedule_outputs turns groups of these into ScheduleItems."""
  # one root LazyOp per output: a STORE for fused kernels, or the bare op for CUSTOM/SYNC/COPY/EMPTY
  ast: Tuple[LazyOp, ...]
  # the LazyBuffer(s) this item writes
  outputs: Tuple[LazyBuffer, ...]
  # LazyBuffers read by the item; for fused kernels this is the membuf list after the output (MemBuffer index 1..)
  inputs: Tuple[LazyBuffer, ...]
  # symbolic Variable bindings collected from the ShapeTrackers while building `ast`
  var_vals: Dict[Variable, int]
# recursively create a lazyop
def _recursive_lazyop(buf:LazyBuffer, membufs:List[LazyBuffer], var_vals:Dict[Variable, int], st:ShapeTracker,
                      realizes:Set[LazyBuffer], cache:Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp], first:bool=True,
                      assign_to:Optional[LazyBuffer]=None, assign_idx:Optional[int]=None) -> LazyOp:
  """Recursively build the fused LazyOp tree that computes `buf` seen through the ShapeTracker `st`.

  buf: buffer to express; a view is folded into `st` and replaced by its base.
  membufs: the kernel's buffer list, mutated in place. A buffer that is loaded rather than fused is appended the first
    time it is met, and its position in this list is the MemBuffer index of the LOAD.
  var_vals: mutated in place with the variable bindings unbound from every CONST/LOAD ShapeTracker.
  st: ShapeTracker applied on top of the buffer's own view.
  realizes: buffers that are kernel boundaries; reaching one (or an already realized buffer) emits a LOAD, except on the
    root call where `first` is True so the kernel's own output is fused instead of loaded.
  cache: memo keyed by (base buffer, ShapeTracker), mutated in place.
  assign_to/assign_idx: while scheduling an ASSIGN, its target buffer and the membufs index LOADs of it must use.

  Raises RuntimeError when an ASSIGN target is read through a non-contiguous view (a single masked view that matches a
  shrunk contiguous view is allowed).
  """
  # NOTE(review): this lookup uses the key before the view->base conversion below, while the store at the end uses the
  # base key (and, for reduces, the source-shaped st), so lookups with a view or for a reduce never hit the cache.
  if (buf, st) in cache: return cache[(buf, st)]
  if buf != buf.base:
    st = buf.st + st
    buf = buf.base
  # all buffers here are base now
  assert buf.op is not None
  # consts are always fused and generated
  if buf.op is LoadOps.CONST:
    unbound_st, st_var_vals = st.simplify().unbind()
    var_vals.update(st_var_vals)
    return LazyOp(BufferOps.CONST, (), ConstBuffer(buf.arg, buf.dtype, unbound_st))
  # if we aren't fusing it, it's a load and we add it to the inputs
  if buf.realized or (buf in realizes and not first):
    unbound_st, st_var_vals = st.simplify().unbind()
    var_vals.update(st_var_vals)
    # reads of the ASSIGN target go to the output slot (assign_idx) instead of becoming a separate input
    if assign_to is not None and buf is assign_to:
      assert assign_idx is not None
      if not unbound_st.contiguous:
        # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
        if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
            ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
          raise RuntimeError(f"must be contiguous for assign {unbound_st}")
      return LazyOp(BufferOps.LOAD, (), MemBuffer(assign_idx, buf.dtype, unbound_st))
    if buf not in membufs: membufs.append(buf)
    return LazyOp(BufferOps.LOAD, (), MemBuffer(membufs.index(buf), buf.dtype, unbound_st))
  # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
  if buf.op is LoadOps.CONTIGUOUS:
    assert first
    return _recursive_lazyop(buf.srcs[0], membufs, var_vals, st, realizes, cache, False)
  if buf.op is LoadOps.ASSIGN:
    assert first
    assert buf.srcs[1].base is buf.srcs[1], "assign must be to base"
    assert buf.srcs[1].realized is not None, f"assign must be already realized to schedule {buf.srcs[1]}"
    # srcs[0] is the value being computed, srcs[1] the target; the ASSIGN buffer itself holds the output slot in membufs
    return _recursive_lazyop(buf.srcs[0], membufs, var_vals, st, realizes, cache, False, assign_to=buf.srcs[1], assign_idx=membufs.index(buf))
  # if it's a reduce, we have to change the shapetracker
  if buf.op in ReduceOps:
    assert st.contiguous, "ReduceOps late fusion must be contiguous"
    st = ShapeTracker.from_shape(buf.srcs[0].shape)
  # otherwise we fuse it like normal
  cache[(buf, st)] = ret = \
    LazyOp(buf.op, tuple(_recursive_lazyop(x, membufs, var_vals, st, realizes, cache, False, assign_to, assign_idx) for x in buf.srcs), buf.arg)
  return ret
def _schedule_one(out:LazyBuffer, realizes:Set[LazyBuffer], reduce_for_op: Dict[LazyBuffer, LazyBuffer]) -> _LBScheduleItem:
  """Preschedule `out` on its own.

  CUSTOM/SYNC/COPY/EMPTY are not fused: the AST is the bare op and the buffer's sources are the inputs.
  Everything else is fused by _recursive_lazyop into a single STORE to MemBuffer 0 (`out` itself). The iteration shape
  comes from the paired reduce in `reduce_for_op` when there is one, otherwise from `out.shape`.
  """
  var_vals: Dict[Variable, int] = out.st.var_vals.copy()
  if out.op in {LoadOps.CUSTOM, LoadOps.SYNC, LoadOps.COPY, LoadOps.EMPTY}:
    return _LBScheduleItem((LazyOp(out.op, (), out.arg),), (out,), tuple(out.srcs), var_vals)
  paired_reduce = reduce_for_op[out] if out in reduce_for_op else None
  output_st = ShapeTracker.from_shape(out.shape if paired_reduce is None else paired_reduce.shape)
  # slot 0 is the output; _recursive_lazyop appends every buffer it loads after it
  kernel_bufs: List[LazyBuffer] = [out]
  value = _recursive_lazyop(out, kernel_bufs, var_vals, output_st, realizes, cache={})
  store = LazyOp(BufferOps.STORE, (value, ), MemBuffer(0, out.dtype, output_st.simplify().unbind()[0]))
  return _LBScheduleItem((store,), (out,), tuple(kernel_bufs[1:]), var_vals)
def _schedule_outputs(outs:List[_LBScheduleItem], reduce_for_op:Dict[LazyBuffer, LazyBuffer]) -> ScheduleItem:
  """Lower a group of prescheduled items into one ScheduleItem.

  A single item is converted as is. Several items are fused into one multi-output kernel:
  - they are ordered by the key of the value under their STORE;
  - every output's AST is rebuilt against one shared buffer list (all outputs first, then the deduplicated inputs);
  - each output is stored to its own index in that list.
  """
  if len(outs) == 1:
    item = outs[0]
    return ScheduleItem(item.ast, (item.outputs[0].buffer,), tuple(b.buffer for b in item.inputs), item.var_vals)
  # sort the outputs before fusing
  ordered = sorted(outs, key=lambda it: it.ast[0].src[0].key)
  # a dict used as an insertion-ordered set, so inputs shared between outputs appear once
  inputs: Dict[LazyBuffer, None] = dict.fromkeys(b for it in ordered for b in it.inputs)
  var_vals = merge_dicts([it.outputs[0].st.var_vals.copy() for it in ordered])
  outputs = [it.outputs[0] for it in ordered]
  stores: List[LazyOp] = []
  for idx, out in enumerate(outputs):
    full_shape = reduce_for_op[out].shape if out in reduce_for_op else out.shape
    output_st = ShapeTracker.from_shape(full_shape)
    # the inputs double as the realize set, so each of them is emitted as a LOAD instead of being fused again
    value = _recursive_lazyop(out, outputs+list(inputs), var_vals, output_st, set(inputs), {})
    stores.append(LazyOp(BufferOps.STORE, (value, ), MemBuffer(idx, out.dtype, output_st.simplify().unbind()[0])))
  return ScheduleItem(tuple(stores), tuple(b.buffer for b in outputs), tuple(b.buffer for b in inputs), var_vals)
# recursively search the entire graph for all LazyBuffers, insert realizes after expands
def _recurse_lb(buf:LazyBuffer, realizes:Set[LazyBuffer], allbufs:Dict[LazyBuffer, None],
                simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled:bool=False) -> None:
  """Depth-first walk of the unrealized graph under `buf`, filling the scheduler's bookkeeping in place.

  realizes: gains the base of every view with more elements than its base (unless it is recorded in simple_pads instead),
    every forced_realize buffer, every LoadOps buffer, and the source base of every COPY.
  allbufs: every unrealized base visited (a dict used as an insertion-ordered set); it is also the visited check.
  simple_pads: bases reached through such a view when the view is a single masked view whose masked region has no more
    elements than the base.
  children: children[src.base][consumer] = None for every source edge of a visited base.
  scheduled: only forwarded to log_lazybuffer when GRAPH is set.

  Image-dtyped buffers whose element count differs from the image shape's, or that have no unit-stride axis with a size
  divisible by 4, are rewritten to float32 (for a base, its backing buffer too).
  """
  if buf in allbufs or buf.base.realized: return
  if GRAPH: log_lazybuffer(buf, scheduled)
  if isinstance(buf.dtype, ImageDType) and (prod(buf.shape) != prod(buf.dtype.shape) or
                                            not any(buf.shape[x]%4 == 0 for x in buf.st.unit_stride_axes())):
    if DEBUG >= 3: print(f"forcing image {buf.dtype} with shape {buf.shape} to float32")
    buf.dtype = dtypes.float32  # NOTE: this is what makes the dtype above not match
    # hack the underlying buffer too
    if buf.base is buf:
      # only possible before the device memory is allocated
      assert not hasattr(buf.buffer, '_buf'), "can't fixup allocated buffer"
      buf.buffer.dtype = dtypes.float32
      buf.buffer.options = None
  if buf.base != buf:
    # realize all places where the buffer is expanded
    if prod(buf.base.st.shape) < prod(buf.st.shape):
      # a pad (single masked view not covering more than the base) may still be fused; create_schedule checks it later
      if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
          prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
        simple_pads.add(buf.base)
      else:
        realizes.add(buf.base)
    # views are never recorded themselves; continue from the base
    return _recurse_lb(buf.base, realizes, allbufs, simple_pads, children)
  if buf.forced_realize: realizes.add(buf)
  allbufs[buf] = None
  if buf.op in LoadOps: realizes.add(buf.base)
  if buf.op is LoadOps.COPY:
    assert buf.srcs[0].st.contiguous and buf.srcs[0].size == buf.srcs[0].base.size, "can only copy contig"
    realizes.add(buf.srcs[0].base)
  for x in buf.srcs:
    children[x.base][buf] = None
    _recurse_lb(x, realizes, allbufs, simple_pads, children)
# ops that block fusing a simple pad: if one appears in the unrealized graph behind a padded base, _is_padding_okay
# returns False and create_schedule realizes that base instead.
# NOTE(review): presumably because these ops don't keep the padded region's zero fill at zero
# (e.g. x/0, log2(0), exp2(0)=1, 0==0) — confirm.
UNSAFE_PAD_OPS = {BinaryOps.DIV, BinaryOps.CMPLT, BinaryOps.CMPEQ, UnaryOps.LOG2, UnaryOps.EXP2}
def _is_padding_okay(buf:LazyBuffer, realizes:Set[LazyBuffer], cache:Optional[Dict[LazyBuffer, bool]]=None) -> bool:
  """Return whether the unrealized computation behind `buf` is free of UNSAFE_PAD_OPS.

  The walk stops (True) at buffers that are in `realizes` or already realized, since those are loaded rather than fused,
  and returns False as soon as an op from UNSAFE_PAD_OPS is found.

  cache: per-walk memo of results. Without it a sub-graph shared by several paths (a diamond) is re-walked once per
    path, which grows exponentially with graph depth. Results depend on `realizes`, so a cache must not be reused after
    `realizes` changes; the default (None) gives every top-level call a fresh one.
  """
  if cache is None: cache = {}
  if buf in cache: return cache[buf]
  if buf in realizes or buf.realized: return True
  # NOTE: this broke to_image_idx and coder with JIT
  if buf.op in UNSAFE_PAD_OPS: return False
  cache[buf] = ret = all(_is_padding_okay(x.base, realizes, cache) for x in buf.srcs)
  return ret
def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
  """Turn the LazyBuffer graph behind `outs` into an ordered list of ScheduleItems.

  outs: the buffers to realize; their (unrealized) bases become kernel outputs.
  seen: buffers scheduled by earlier calls. They are skipped when prescheduling, and every buffer scheduled here is added
    to it (the set is mutated in place).

  Returns the ScheduleItems in dependency order (Kahn's algorithm over the prescheduled buffers). STORE kernels at the
  same BFS level with the same shape and device are grouped into one multi-output item, unless something keeps them
  single-output: DISK/METAL device, DISALLOW_MULTIOUT, reduce kernels, buffers paired with a reduce, or forced_realize.

  Side effect: `srcs` is deleted from every scheduled output LazyBuffer, so a buffer can only be scheduled once.
  Raises RuntimeError if the in-degrees don't all drain to zero or not every prescheduled buffer was scheduled
  (e.g. a dependency cycle).
  """
  if seen is None: seen = set()
  # start by just realizing the buffers passed in
  realizes: Set[LazyBuffer] = {x.base for x in outs if not x.base.realized}
  allbufs: Dict[LazyBuffer, None] = {}
  simple_pads: Set[LazyBuffer] = set()
  children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict)
  for out in outs: _recurse_lb(out.base, realizes, allbufs, simple_pads, children, scheduled=True)
  # check if we have to realize pads
  for p in simple_pads:
    if not _is_padding_okay(p, realizes):
      realizes.add(p)
  # find all reduces, and pair them to an elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child)
  reduce_for_op: Dict[LazyBuffer, LazyBuffer] = {}
  for r in allbufs.keys():
    if r != r.base or r.op not in ReduceOps or r in realizes: continue
    # follow the reduce down: child_set maps each frontier buffer to the ShapeTracker the reduce output is seen through
    child_set: Dict[LazyBuffer, ShapeTracker] = {r: r.st}
    realized_children: Dict[LazyBuffer, ShapeTracker] = {}
    forced_realize = False
    can_chase = True
    while not forced_realize and len(child_set):
      next_child_set = {}
      for tr,st in child_set.items():
        if tr in realizes:
          realized_children[tr] = st
          # can only have one output buffer
          # can only reduce contiguous
          # max one reduceop per kernel
          if len(realized_children) > 1 or not st.contiguous or st.size != r.st.size or (tr in reduce_for_op and reduce_for_op[tr] != r):
            can_chase = tr not in reduce_for_op or reduce_for_op[tr] == r
            forced_realize = True
            break
          continue
        for tr_next in children[tr].keys():
          if not tr_next.realized:
            # NOTE: the breaks below only leave this inner loop; the loop over child_set keeps going and the
            # while loop stops on forced_realize
            # max one reduceop per kernel
            if tr_next.op in ReduceOps:
              forced_realize = True
              break
            st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
            if len(st_childs) > 1:
              forced_realize = True
              break
            next_child_set[tr_next] = st + st_childs[0].st
      child_set = next_child_set
    if forced_realize:
      tr = r
      if can_chase:
        # can chase this down to contiguous children
        st = tr.st
        while len(children[tr]) == 1:
          tr_next = next(iter(children[tr].keys()))
          st_childs = dedup([s for s in tr_next.srcs if s.base == tr])
          if len(st_childs) > 1: break
          if st.size != st_childs[0].st.size: break
          st = st + st_childs[0].st
          if not st.contiguous or tr_next.op in ReduceOps: break
          tr = tr_next
        reduce_for_op[tr] = r
      realizes.add(tr)
    else:
      # the walk ended cleanly at exactly one realized child, which becomes the kernel that contains the reduce
      assert len(realized_children) == 1
      reduce_for_op[next(iter(realized_children.keys()))] = r
  # preschedule all buffers in realizes
  prescheduled = {x:_schedule_one(x, realizes, reduce_for_op) for x in realizes if x not in seen and x.realized is None and x.op is not LoadOps.CONST}
  # ASSIGN target -> the ASSIGN writing it
  assign_targets = {x.srcs[1]:x for x in realizes if x.op is LoadOps.ASSIGN and x not in seen and x.realized is None}
  # breadth first ordering
  graph: DefaultDict[LazyBuffer, List[LazyBuffer]] = defaultdict(list)
  in_degree: DefaultDict[LazyBuffer, int] = defaultdict(int)
  for out, si in prescheduled.items():
    for x in si.inputs:
      graph[x].append(out)
      # every reader of an ASSIGN target must run before the ASSIGN overwrites it
      if x in assign_targets:
        graph[out].append(assign_targets[x])
        in_degree[assign_targets[x]] += 1
      if x in prescheduled: in_degree[out] += 1
  queue: Deque[Tuple[int, LazyBuffer]] = deque((0, out) for out in prescheduled if in_degree[out] == 0)
  output_groups: DefaultDict[Tuple, List[_LBScheduleItem]] = defaultdict(list)
  while queue:
    level, buf = queue.popleft()
    seen.add(buf)
    ps = prescheduled[buf]
    # single output
    if ps.ast[0].op is not BufferOps.STORE or buf.device.startswith("DISK") or buf.device == "METAL" or getenv("DISALLOW_MULTIOUT") \
        or buf.op in ReduceOps or buf in reduce_for_op or buf.forced_realize: key: Tuple = (buf,)
    # multi output
    else: key = (level, buf.shape, buf.device)
    output_groups[key].append(ps)
    for x in graph[buf]:
      in_degree[x] -= 1
      if in_degree[x] == 0: queue.append((level+1, x))
  kernel_number = GlobalCounters.kernel_count
  schedule: List[ScheduleItem] = []
  for group in output_groups.values():
    if GRAPH:
      kernel_number += 1
      for ps in group: realized_lazybuffer(ps.outputs[0], kernel_number)
    schedule.append(_schedule_outputs(group, reduce_for_op))
    for ps in group: del ps.outputs[0].srcs # can only schedule once
  # confirm everything was scheduled correctly
  # count scheduled output buffers, not ScheduleItems: one multi-output item carries several buffers
  scheduled_outputs = len(flatten(s.outputs for s in schedule))
  if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != scheduled_outputs:
    raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {scheduled_outputs}")
  return schedule