-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[wip] resnet batchnorm backward fusion spec #4370
base: master
Are you sure you want to change the base?
Changes from all commits
1f09880
1de3453
0142087
10b7e6f
b020f45
7eb504e
b51346c
c259cac
02def56
a6a7027
c5ba80c
c521e91
f448c32
f949536
dc82999
2012d88
9a1b1fb
45443f4
745ff66
4cff10c
520d2eb
fb82a6d
fe8bb23
35a5e94
feb72c6
338f137
2d6de23
0882eea
ab7ed92
e1e52a9
6301d1b
f329d75
7875b26
889d1fb
2b303e4
4878be8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ | |
from typing import List, Optional, Union | ||
from tinygrad.tensor import Tensor | ||
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps | ||
from tinygrad.helpers import DEBUG, GRAPH, flatten | ||
from tinygrad.helpers import DEBUG, GRAPH, flatten, getenv | ||
from tinygrad.codegen.linearizer import Linearizer | ||
from tinygrad.features.graph import print_tree, realized_lazybuffer | ||
from tinygrad.engine.schedule import create_schedule | ||
|
@@ -32,7 +32,8 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt | |
for i, s in enumerate(sched): | ||
print("kernel", i+1) | ||
for op in s.ast: print_tree(op) | ||
assert len(sched) == allowed | ||
assert len(sched) == allowed, f"{len(sched)}, {allowed}" | ||
if getenv("SKIP_LIN"): return sched | ||
# test the (non loadops) ops linearize | ||
for s in sched: | ||
if s.ast[0].op in LoadOps: continue | ||
|
@@ -680,6 +681,140 @@ def test_prefer_half_buffer(self): | |
# sched = check_schedule([b, c], 4) | ||
# doesn't store either in half because it doesn't chase | ||
|
||
def test_batchnorm_train_backward_fusion(self): | ||
with Tensor.train(): | ||
x = Tensor.empty((2, 16, 8, 8)).contiguous() | ||
bn = nn.BatchNorm2d(16) | ||
bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True | ||
fw = bn(x).contiguous_backward().relu().contiguous() | ||
fw.sum().backward() | ||
# we want to minimize number of passes over buffers of same size as x | ||
# start: 12 kernels (some extraneous from constructing this test case) | ||
# easy case: merge 4 reduces in backward into 1 | ||
# double reduce case: merge stat calculations from 2 to 1 (be careful of long reduces!) | ||
# sum(x - \bar{x}): one kernel just calculates this, can be eliminated | ||
# pre-expand fusion: is it fast? -2 kernels possible, 1 fw, 1 bw | ||
# merge reduce into previous conv: -2 kernels on top of the above. requires big linearizer change. | ||
# ideal case: the forward + backward pass is just 3 convs and some small reduces! | ||
check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 9) | ||
|
||
def test_parallel_reduce_fusion(self): | ||
x = Tensor.empty(16, 16) | ||
y = Tensor.empty(16, 16) | ||
z = Tensor.empty(16, 16) | ||
a = Tensor.empty(1, 16) | ||
b = Tensor.empty(1, 16) | ||
|
||
# we want to fuse reduces where the inputs of "significant" size of one reduce | ||
# are a superset of the significant inputs of another reduce | ||
|
||
# === -4 memory passes === | ||
|
||
# should fuse reduces that share common input, indexing on larger inputs | ||
check_schedule([(x + a).sum(), (x + b).sum()], 1) | ||
|
||
# same as above, except one of the sums has 2 big buffers | ||
check_schedule([(x + y + a).sum(), (x + b).sum()], 1) | ||
|
||
# same as above, except both sums have 2 big buffers | ||
check_schedule([(x + y + a).sum(), (x + y + b).sum()], 1) | ||
|
||
# same as above, except with 3 sums | ||
# do not necessarily require the (x, a) (x, b) (a, b) case | ||
check_schedule([(x + y + a).sum(), (x + b).sum(), (y + b).sum()], 1) | ||
|
||
# same as above, except with different late ASTs | ||
check_schedule([(x + y + a).sum() * 2, (x + b).sum() * 3], 1) | ||
|
||
# same as above, except with permuted STs | ||
check_schedule([(x.permute(1, 0) + y + a).sum() * 2, (x.permute(1, 0) + b).sum() * 3], 1) | ||
|
||
# for now, we only want to do this fusion when no significant inputs are added | ||
# this is because adding significant inputs is not free when doing gemms, since | ||
# there is limited L1 cache space. it is probably OK to fuse when there is at most | ||
# one expand axis. | ||
check_schedule([(x + y).sum(), (x + z).sum()], 2) | ||
|
||
# pick someone to fuse into if there is ambiguity | ||
check_schedule([(x + y).sum(), (x + z).sum(), (x + a).sum()], 2) | ||
|
||
# don't fuse if shapetrackers do not match | ||
check_schedule([(x + a).sum(), (x.permute(1, 0) + b).sum()], 2) | ||
|
||
# don't fuse if sums do not match | ||
check_schedule([(x + a).sum(axis=0), (x + b).sum(axis=1)], 2) | ||
check_schedule([(x + a).sum(axis=0), (x + b).sum()], 2) | ||
|
||
# maybe don't fuse if is 2-axis expand (gemm) and the early asts do not match? | ||
# (because there might be too many accumulators) | ||
# OK for now since we only fuse for bijective st | ||
|
||
# match by input + ST and two shapes? start with contiguous input only, check shapes (should determine reduces) | ||
|
||
# what if same input + st but one is early and another is late? | ||
check_schedule([x.sum(0, keepdim=True) + a, (a + b).sum()], 2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is this a real-world case? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. could be... maybe if you have a bias weight and out.sum(0) + bias -> next layer (bias**2).sum() -> LARS ? |
||
|
||
# don't fuse early reduce with elementwise! | ||
check_schedule([x.sum(), x + y], 2) | ||
|
||
# no fuse f(expand(reduce(x)), x) "diamonds" | ||
# or at least it should be fused sequential instead of parallel | ||
# fused group needs to be contiguous subDAG -- construct with toposort | ||
sum1 = (x + a).sum() | ||
check_schedule([sum1, (x + sum1).sum()], 2) | ||
del sum1 | ||
|
||
# super tricky crossing dag case | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The (conservative) heuristic I am using is that this fusion should never add extra loads from bijective shapetrackers. If a shapetracker is bijective, then its size matches the full_shape of the kernel, and all non-bijective loads must be from smaller buffer(region)s. In the normal case, the non-bijective "small" buffers are from expands and are very small compared to the bijective ones (here it's 1/16), so adding these won't hurt. Here, fusing the diagonal will save 1 memory pass over a big buffer. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In fact, for simple reduces like these from bijective shapetrackers, it should be fine to fuse many unrelated reduces. Simple reduces don't really need a lot of cache -- the cache really helps when you have expands like (1, a) * (b, 1), since you can do an nm-sized tile with only n + m loads. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. i think this may even be a real world case -- consider x and y to be the forward outputs of different layers. |
||
sum1 = (x + a).sum() | ||
sum2 = (y + b).sum() | ||
check_schedule([(y + sum1).sum(), (x + sum2).sum()], 3) | ||
del sum1, sum2 | ||
|
||
# todo: test _merge_prescheduled or otherwise verify buffers indexes are correct | ||
|
||
@unittest.skip("not useful for bn backward") | ||
def test_parallel_r_e_fusion(self): | ||
x = Tensor.empty(16, 16) | ||
y = Tensor.empty(16, 16) | ||
z = Tensor.empty(16, 16) | ||
a = Tensor.empty(1, 16) | ||
b = Tensor.empty(1, 16) | ||
|
||
# do parallel fusion also for elementwise | ||
check_schedule([(x + a).sum(), (x * b)], 1) | ||
|
||
# do this also for *subtrees* of elementwise, when it will save memory bandwidth | ||
stat = (x + z + a).sum(axis=0, keepdim=True) | ||
check_schedule([stat, ((x + z + b) * stat + y)], 1) | ||
|
||
# don't steal if it doesn't reduce mem bw | ||
stat = (x + a).sum(axis=0, keepdim=True) | ||
check_schedule([stat, ((x + b) * stat + y)], 2) | ||
|
||
# don't fuse if shapetrackers do not match | ||
check_schedule([(x + a).sum(), (x.permute(1, 0) + b)], 2) | ||
|
||
def test_preconv_e_fusion(self): | ||
x = Tensor.empty(16, 16) | ||
y = Tensor.empty(16, 16) | ||
z = Tensor.empty(16, 16) | ||
a = Tensor.empty(1, 16) | ||
conv = nn.Conv2d(16, 16, 3) | ||
conv.weight = Tensor.empty(conv.weight.shape) | ||
conv.bias = Tensor.empty(conv.bias.shape) | ||
|
||
# === -4 memory passes (2 dependent on fusing conv(a + b)) === | ||
|
||
# fuse when the input has 1 big buffer | ||
check_schedule([conv(x + a)], 1) | ||
|
||
# fuse when the input has 2 big buffer | ||
# very annoying that bn backward needs to fuse 2 big buffer | ||
check_schedule([conv(x + y + a)], 1) | ||
|
||
# (for now) don't fuse when the input has 3 big buffer | ||
check_schedule([conv(x + y + z)], 2) | ||
|
||
def test_reduce_simple_chase(self): | ||
a = Tensor.empty(4, 4, 4) | ||
r = a.sum(0) + 6 | ||
|
Original file line number | Diff line number | Diff line change | ||
---|---|---|---|---|
@@ -1,10 +1,10 @@ | ||||
import sys, pickle, atexit | ||||
from collections import defaultdict, deque | ||||
from dataclasses import dataclass | ||||
from typing import Tuple, List, Dict, Optional, Set, DefaultDict | ||||
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Deque | ||||
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps | ||||
from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer | ||||
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, prod, dedup, all_int, merge_dicts, getenv | ||||
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, prod, dedup, all_int, merge_dicts, getenv, flatten | ||||
from tinygrad.shape.symbolic import Variable | ||||
from tinygrad.dtype import ImageDType, dtypes | ||||
from tinygrad.lazy import LazyBuffer | ||||
|
@@ -92,6 +92,18 @@ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None] | |||
ast.append(LazyOp(BufferOps.STORE, (op, ), MemBuffer(i, out.dtype, output_view))) | ||||
return _LBScheduleItem(tuple(ast), outs, tuple(inputs), var_vals) | ||||
|
||||
def _replace_bufis(ast: LazyOp, old_lbs: Tuple[LazyBuffer, ...], new_lbs: Tuple[LazyBuffer, ...]): | ||||
new_arg = MemBuffer(new_lbs.index(old_lbs[ast.arg.idx]), ast.arg.dtype, ast.arg.st) if ast.op in [BufferOps.LOAD, BufferOps.STORE] else ast.arg | ||||
return LazyOp(ast.op, tuple(_replace_bufis(x, old_lbs, new_lbs) for x in ast.src), new_arg) | ||||
|
||||
def _merge_prescheduled(prescheduled: List[_LBScheduleItem]): | ||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've gone through this route in multioutput, tinygrad/tinygrad/engine/schedule.py Line 86 in 6c2cb8e
I think you need to rebuild the entire AST. |
||||
# todo: need to toposort them somewhere | ||||
inputs: Tuple[LazyBuffer, ...] = tuple(dedup(flatten(psi.inputs for psi in prescheduled))) | ||||
outputs: Tuple[LazyBuffer, ...] = tuple(dedup(flatten(psi.outputs for psi in prescheduled))) | ||||
var_vals: Dict[Variable, int] = merge_dicts([psi.var_vals.copy() for psi in prescheduled]) | ||||
ast: Tuple[LazyOp] = tuple(_replace_bufis(ast, psi.outputs+psi.inputs, outputs+inputs) for psi in prescheduled for ast in psi.ast) | ||||
return _LBScheduleItem(ast, outputs, inputs, var_vals) | ||||
|
||||
# recursively search the entire graph for all LazyBuffers, insert realizes after expands | ||||
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None], | ||||
simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False): | ||||
|
@@ -236,6 +248,73 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul | |||
graph[key].append(assign) | ||||
in_degree[assign] += 1 | ||||
|
||||
def get_bijectives(csi): | ||||
bijectives = set() | ||||
for ast in csi.ast: | ||||
for lop in ast.lazyops: | ||||
if lop.op is BufferOps.LOAD: | ||||
membuf: MemBuffer = lop.arg | ||||
if membuf.st.shape == csi.outputs[0].st.shape or not membuf.st.bijective: continue # check if is earlybuf and bijective | ||||
bijectives.add(((csi.outputs + csi.inputs)[membuf.idx], membuf.st, csi.outputs[0].st)) | ||||
return bijectives | ||||
pre_q: Deque[_LBScheduleItem] = deque() | ||||
reduce_collector: List[Tuple[_LBScheduleItem, Set[Tuple[LazyBuffer, ShapeTracker, ShapeTracker]]]] = [] | ||||
for csi in (si for key, si in prescheduled.items() if in_degree[key] == 0): | ||||
bijectives = get_bijectives(csi) | ||||
if bijectives: reduce_collector.append((csi, bijectives)) | ||||
else: pre_q.append(csi) | ||||
preschedule_groups = [] | ||||
pre_in_deg = {k: v for k, v in in_degree.items()} | ||||
while pre_q or reduce_collector: | ||||
if pre_q: | ||||
ps = pre_q.popleft() | ||||
else: | ||||
ps, bij = reduce_collector.pop(0) | ||||
to_group = [] | ||||
this_group = [ps] | ||||
best_bij = bij | ||||
for rci, (ps_, bij_) in enumerate(reduce_collector): | ||||
if best_bij <= bij_ or bij_ <= best_bij: | ||||
to_group.append(rci) | ||||
best_bij = best_bij | bij_ | ||||
for rcoff, rci in enumerate(to_group): | ||||
ps_, bij_ = reduce_collector.pop(rci - rcoff) | ||||
this_group.append(ps_) | ||||
pre_q.append(ps_) | ||||
if len(this_group) > 1: | ||||
preschedule_groups.append(this_group) | ||||
|
||||
to_enqueue: List[_LBScheduleItem] = [] | ||||
for x in graph[ps.outputs[0]]: | ||||
pre_in_deg[x] -= 1 | ||||
if pre_in_deg[x] == 0: to_enqueue.append(prescheduled[x]) | ||||
# chase to children with ps as contiguous earlybuf, match by reduced shape | ||||
for csi in to_enqueue: | ||||
bijectives = get_bijectives(csi) | ||||
if bijectives: reduce_collector.append((csi, bijectives)) | ||||
else: pre_q.append(csi) | ||||
|
||||
# edit the graph with the new groupings | ||||
for lsigroup in preschedule_groups: | ||||
merged_lsi = _merge_prescheduled(lsigroup) | ||||
merged_edges, merged_in_deg = [], 0 | ||||
for k, edges in graph.items(): | ||||
new_edges = [lb for lb in edges if lb not in [lsi.outputs[0] for lsi in lsigroup]] | ||||
if len(new_edges) != len(edges): | ||||
assert k not in [lsi.outputs[0] for lsi in lsigroup], "cycle?" | ||||
new_edges += [merged_lsi.outputs[0]] | ||||
merged_in_deg += 1 | ||||
graph[k] = new_edges | ||||
|
||||
for k in [lsi.outputs[0] for lsi in lsigroup]: | ||||
merged_edges.extend(graph[k]) | ||||
del graph[k] | ||||
del in_degree[k] | ||||
del prescheduled[k] | ||||
graph[merged_lsi.outputs[0]] = merged_edges | ||||
in_degree[merged_lsi.outputs[0]] = merged_in_deg | ||||
prescheduled[merged_lsi.outputs[0]] = merged_lsi | ||||
|
||||
return graph, in_degree, prescheduled | ||||
|
||||
SCHEDULES: List = [] | ||||
|
@@ -269,7 +348,7 @@ def _save(): | |||
# confirm everything was scheduled correctly | ||||
if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule): | ||||
raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}") | ||||
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels") | ||||
if DEBUG >= 1 and len(schedule) >= 10 or DEBUG >= 3: print(f"scheduled {len(schedule)} kernels") | ||||
return schedule, var_vals | ||||
|
||||
def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]: | ||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does this refer to
E_2_16_64n1
+E_2048
(graph ref: https://tiny-tools-client.vercel.app/?id=f7b72a41bad14974970329924c89b2c0)
?
#4235 could do this, it won't because
<LB METAL (2, 16, 8, 8) float (<UnaryOps.CAST: 3>, None)>
is forced_realize. I think it breaks the API if we fuse a forced_realize parent with its child.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am referring to E2_16_64n1 (full forward with relu) and E2_16_64 (full backward through batchnorm). The first can be fused with the next conv, and the latter can be fused with the next backward conv. (E_2048 simulates the backward from the next layer, plus relu backward)
This test case does not have the convs to focus on batchnorm, so it cannot happen here. will add more cases.