[wip] resnet batchnorm backward fusion spec #4370

Draft: chaosagent wants to merge 36 commits into master from chaosagent/bn_bw_sched_spec

Commits (36):
1f09880 start (chaosagent, May 1, 2024)
1de3453 add comments (chaosagent, May 1, 2024)
0142087 start on parallel reduce fusion spec (chaosagent, May 2, 2024)
10b7e6f more cases (chaosagent, May 2, 2024)
b020f45 adjust subtrees of elementwise case (chaosagent, May 2, 2024)
7eb504e test_preconv_e_fusion (chaosagent, May 2, 2024)
b51346c use b instead of 3 (chaosagent, May 2, 2024)
c259cac whitespace (chaosagent, May 2, 2024)
02def56 I didn't know what Tensor.empty did lmao (chaosagent, May 2, 2024)
a6a7027 typo (chaosagent, May 2, 2024)
c5ba80c add accounting (chaosagent, May 2, 2024)
c521e91 don't fuse if st don't match also for r_e (chaosagent, May 2, 2024)
f448c32 correct fake news (chaosagent, May 2, 2024)
f949536 Merge branch 'master' into bn_bw_sched_spec (Qazalin, May 3, 2024)
dc82999 is this how the linearizer works (chaosagent, May 3, 2024)
2012d88 Merge remote-tracking branch 'chaosagent/bn_bw_sched_spec' into bn_bw… (chaosagent, May 3, 2024)
9a1b1fb different late asts (chaosagent, May 3, 2024)
45443f4 add different late ast (chaosagent, May 3, 2024)
745ff66 don't fuse if sums do not match? (chaosagent, May 3, 2024)
4cff10c can there be too many accumulators for gemm? (chaosagent, May 3, 2024)
520d2eb add todos (chaosagent, May 5, 2024)
fb82a6d what the heck (chaosagent, May 5, 2024)
fe8bb23 defer reduces to match them (chaosagent, May 6, 2024)
35a5e94 also do the split queue logic for initial elements, this passes all m… (chaosagent, May 6, 2024)
feb72c6 diamond test (chaosagent, May 6, 2024)
338f137 permute test (chaosagent, May 6, 2024)
2d6de23 outputs+inputs instead of inputs+outputs (chaosagent, May 6, 2024)
0882eea add todo (chaosagent, May 6, 2024)
ab7ed92 tricky crossing dag case (chaosagent, May 6, 2024)
e1e52a9 del (chaosagent, May 6, 2024)
6301d1b comment (chaosagent, May 6, 2024)
f329d75 early/late cases (chaosagent, May 6, 2024)
7875b26 :P (chaosagent, May 7, 2024)
889d1fb Revert ":P", it wasn't needed (chaosagent, May 7, 2024)
2b303e4 refactor (chaosagent, May 8, 2024)
4878be8 formatting (chaosagent, May 8, 2024)
26 changes: 26 additions & 0 deletions test/test_linearizer.py
@@ -94,6 +94,32 @@ def test_multioutput(self):
assert len(mutable_bufs) == len(stores) == 2
assert [u.arg[0] for u in mutable_bufs] == [0, 1]

def test_multioutput_parallel_r(self):
dtype, st, rst = dtypes.int, ShapeTracker.from_shape((8,2)), ShapeTracker.from_shape((1,2))
a = LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtype, st=st))
b = LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=3, dtype=dtype, st=st))
c = LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=4, dtype=dtype, st=st))
const2 = LazyOp(BufferOps.CONST, tuple(), arg=2.0)
const3 = LazyOp(BufferOps.CONST, tuple(), arg=3.0)
sum1 = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(a,b)),))
sum2 = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(a,c)),))
late1 = LazyOp(op=BinaryOps.ADD, src=(sum1, const2))
late2 = LazyOp(op=BinaryOps.ADD, src=(sum2, const3))
out0 = LazyOp(BufferOps.STORE, (late1,), MemBuffer(idx=0, dtype=dtype, st=rst))
out1 = LazyOp(BufferOps.STORE, (late2,), MemBuffer(idx=1, dtype=dtype, st=rst))

lin = Linearizer(out0, out1)
lin.linearize()

stores = [u for u in lin.uops if u.uop is UOps.STORE]
mutable_bufs = [u for u in lin.uops if u.uop is UOps.DEFINE_GLOBAL and u.arg[-1]]
assert len(mutable_bufs) == len(stores) == 2
assert [u.arg[0] for u in mutable_bufs] == [0, 1]

# todo: test different dtype outputs
# todo: test different output sts


def test_load_dedup(self):
# for different leaves in the AST, the same loads may occur.

139 changes: 137 additions & 2 deletions test/test_schedule.py
@@ -6,7 +6,7 @@
from typing import List, Optional, Union
from tinygrad.tensor import Tensor
from tinygrad.ops import BinaryOps, LoadOps, ReduceOps
from tinygrad.helpers import DEBUG, GRAPH, flatten
from tinygrad.helpers import DEBUG, GRAPH, flatten, getenv
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.features.graph import print_tree, realized_lazybuffer
from tinygrad.engine.schedule import create_schedule
@@ -32,7 +32,8 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
for i, s in enumerate(sched):
print("kernel", i+1)
for op in s.ast: print_tree(op)
assert len(sched) == allowed
assert len(sched) == allowed, f"{len(sched)}, {allowed}"
if getenv("SKIP_LIN"): return sched
# test the (non loadops) ops linearize
for s in sched:
if s.ast[0].op in LoadOps: continue
@@ -680,6 +681,140 @@ def test_prefer_half_buffer(self):
# sched = check_schedule([b, c], 4)
# doesn't store either in half because it doesn't chase

def test_batchnorm_train_backward_fusion(self):
with Tensor.train():
x = Tensor.empty((2, 16, 8, 8)).contiguous()
bn = nn.BatchNorm2d(16)
bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True
fw = bn(x).contiguous_backward().relu().contiguous()
fw.sum().backward()
# we want to minimize number of passes over buffers of same size as x
# start: 12 kernels (some extraneous from constructing this test case)
# easy case: merge 4 reduces in backward into 1
# double reduce case: merge stat calculations from 2 to 1 (be careful of long reduces!)
# sum(x - \bar{x}): one kernel just calculates this, can be eliminated
# pre-expand fusion: is it fast? -2 kernels possible, 1 fw, 1 bw
[review comment, Collaborator] does this refer to E_2_16_64n1 + E_2048 (graph ref: https://tiny-tools-client.vercel.app/?id=f7b72a41bad14974970329924c89b2c0)? #4235 could do this, it won't because <LB METAL (2, 16, 8, 8) float (<UnaryOps.CAST: 3>, None)> is forced_realize. I think it breaks the API if we fuse a forced_realize parent with its child.

[reply, chaosagent] I am referring to E_2_16_64n1 (full forward with relu) and E_2_16_64 (full backward through batchnorm). The first can be fused with the next conv, and the latter can be fused with the next backward conv. (E_2048 simulates the backward from the next layer, plus relu backward.) This test case does not have the convs, to keep the focus on batchnorm, so that fusion cannot happen here; will add more cases.

# merge reduce into previous conv: -2 kernels on top of the above. requires big linearizer change.
# ideal case: the forward + backward pass is just 3 convs and some small reduces!
check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 9)
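
For reference, a minimal numpy sketch (not part of the diff; shapes, eps, and variable names are illustrative) of the reductions batchnorm backward needs. They all walk x/grad-sized buffers over the same axes, which is what the reduce fusion targeted above merges, and sum(x - mean(x)) is identically zero, which is why that kernel can be eliminated outright.

import numpy as np

N, C, H, W = 2, 16, 8, 8
x, dy = np.random.randn(N, C, H, W), np.random.randn(N, C, H, W)
mean = x.mean(axis=(0, 2, 3), keepdims=True)
var = x.var(axis=(0, 2, 3), keepdims=True)
xhat = (x - mean) / np.sqrt(var + 1e-5)

sum_dy = dy.sum(axis=(0, 2, 3))                # feeds bn.bias.grad
sum_dy_xhat = (dy * xhat).sum(axis=(0, 2, 3))  # feeds bn.weight.grad
# these per-channel reductions (and the ones inside dx) read same-sized buffers over the
# same axes, so they can share one pass; sum(x - mean) needs no kernel at all:
assert np.allclose((x - mean).sum(axis=(0, 2, 3)), 0)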

def test_parallel_reduce_fusion(self):
x = Tensor.empty(16, 16)
y = Tensor.empty(16, 16)
z = Tensor.empty(16, 16)
a = Tensor.empty(1, 16)
b = Tensor.empty(1, 16)

# we want to fuse reduces where the inputs of "significant" size of one reduce
# are a superset of the significant inputs of another reduce (see the standalone sketch after this test)

# === -4 memory passes ===

# should fuse reduces that share common input, indexing on larger inputs
check_schedule([(x + a).sum(), (x + b).sum()], 1)

# same as above, except one of the sums has 2 big buffers
check_schedule([(x + y + a).sum(), (x + b).sum()], 1)

# same as above, except both sums have 2 big buffers
check_schedule([(x + y + a).sum(), (x + y + b).sum()], 1)

# same as above, except with 3 sums
# do not necessarily require the (x, a) (x, b) (a, b) case
check_schedule([(x + y + a).sum(), (x + b).sum(), (y + b).sum()], 1)

# same as above, except with different late ASTs
check_schedule([(x + y + a).sum() * 2, (x + b).sum() * 3], 1)

# same as above, except with permuted STs
check_schedule([(x.permute(1, 0) + y + a).sum() * 2, (x.permute(1, 0) + b).sum() * 3], 1)

# for now, we only want to do this fusion when no significant inputs are added
# this is because adding significant inputs is not free when doing gemms, since
# there is limited L1 cache space. it is probably OK to fuse when there is at most
# one expand axis.
check_schedule([(x + y).sum(), (x + z).sum()], 2)

# pick someone to fuse into if there is ambiguity
check_schedule([(x + y).sum(), (x + z).sum(), (x + a).sum()], 2)

# don't fuse if shapetrackers do not match
check_schedule([(x + a).sum(), (x.permute(1, 0) + b).sum()], 2)

# don't fuse if sums do not match
check_schedule([(x + a).sum(axis=0), (x + b).sum(axis=1)], 2)
check_schedule([(x + a).sum(axis=0), (x + b).sum()], 2)

# maybe don't fuse if is 2-axis expand (gemm) and the early asts do not match?
# (because there might be too many accumulators)
# OK for now since we only fuse for bijective st

# match by input + ST and two shapes? start with contiguous input only, check shapes (should determine reduces)

# what if same input + st but one is early and another is late?
check_schedule([x.sum(0, keepdim=True) + a, (a + b).sum()], 2)
[review comment, Collaborator] is this a real-world case?

[reply, chaosagent] could be... maybe if you have a bias weight and out.sum(0) + bias -> next layer, (bias**2).sum() -> LARS?

# don't fuse early reduce with elementwise!
check_schedule([x.sum(), x + y], 2)

# no fuse f(expand(reduce(x)), x) "diamonds"
# or at least it should be fused sequentially instead of in parallel
# fused group needs to be contiguous subDAG -- construct with toposort
sum1 = (x + a).sum()
check_schedule([sum1, (x + sum1).sum()], 2)
del sum1

# super tricky crossing dag case
[review comment, Collaborator] do you think fusing this is faster? (screenshot omitted)

[reply, chaosagent] The (conservative) heuristic I am using is that this fusion should never add extra loads from bijective shapetrackers. If a shapetracker is bijective, then its size matches the full_shape of the kernel, and all non-bijective loads must be from smaller buffers (or buffer regions). In the normal case, the non-bijective "small" buffers come from expands and are very small compared to the bijective ones (here it's 1/16), so adding these won't hurt. Here, fusing the diagonal will save 1 memory pass over a big buffer.

[reply, chaosagent] In fact, for simple reduces like these from bijective shapetrackers, it should be fine to fuse many unrelated reduces. Simple reduces don't really need a lot of cache -- the cache really helps when you have expands like (1, a) * (b, 1), since you can do an n*m-sized tile with only n + m loads.

[reply, chaosagent] I think this may even be a real-world case -- consider x and y to be the forward outputs of different layers.

sum1 = (x + a).sum()
sum2 = (y + b).sum()
check_schedule([(y + sum1).sum(), (x + sum2).sum()], 3)
del sum1, sum2

# todo: test _merge_prescheduled or otherwise verify buffer indexes are correct
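
A standalone sketch (hypothetical names, not from the diff) of the grouping rule the cases above exercise and that _graph_schedule below implements: two reduces may share a kernel only when the significant (bijective, full-sized) inputs of one are a subset of the other's, so fusing never adds a pass over a big buffer; the small expanded buffers like a and b are ignored.

def can_group(sig1, sig2):
  # mirrors the `best_bij <= bij_ or bij_ <= best_bij` subset check in the scheduler
  return sig1 <= sig2 or sig2 <= sig1

sig_xa = {"x"}         # significant inputs of (x + a).sum()
sig_xyb = {"x", "y"}   # significant inputs of (x + y + b).sum()
sig_xz = {"x", "z"}    # significant inputs of (x + z).sum()
print(can_group(sig_xa, sig_xyb))   # True: the merged kernel loads no new big buffers
print(can_group(sig_xyb, sig_xz))   # False: either side would gain an extra big load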

@unittest.skip("not useful for bn backward")
def test_parallel_r_e_fusion(self):
x = Tensor.empty(16, 16)
y = Tensor.empty(16, 16)
z = Tensor.empty(16, 16)
a = Tensor.empty(1, 16)
b = Tensor.empty(1, 16)

# do parallel fusion also for elementwise
check_schedule([(x + a).sum(), (x * b)], 1)

# do this also for *subtrees* of elementwise, when it will save memory bandwidth
stat = (x + z + a).sum(axis=0, keepdim=True)
check_schedule([stat, ((x + z + b) * stat + y)], 1)

# don't steal if it doesn't reduce mem bw
stat = (x + a).sum(axis=0, keepdim=True)
check_schedule([stat, ((x + b) * stat + y)], 2)

# don't fuse if shapetrackers do not match
check_schedule([(x + a).sum(), (x.permute(1, 0) + b)], 2)

def test_preconv_e_fusion(self):
x = Tensor.empty(16, 16)
y = Tensor.empty(16, 16)
z = Tensor.empty(16, 16)
a = Tensor.empty(1, 16)
conv = nn.Conv2d(16, 16, 3)
conv.weight = Tensor.empty(conv.weight.shape)
conv.bias = Tensor.empty(conv.bias.shape)

# === -4 memory passes (2 dependent on fusing conv(a + b)) ===

# fuse when the input has 1 big buffer
check_schedule([conv(x + a)], 1)

# fuse when the input has 2 big buffers
# very annoying that bn backward needs to fuse 2 big buffers
check_schedule([conv(x + y + a)], 1)

# (for now) don't fuse when the input has 3 big buffers
check_schedule([conv(x + y + z)], 2)

def test_reduce_simple_chase(self):
a = Tensor.empty(4, 4, 4)
r = a.sum(0) + 6
85 changes: 82 additions & 3 deletions tinygrad/engine/schedule.py
@@ -1,10 +1,10 @@
import sys, pickle, atexit
from collections import defaultdict, deque
from dataclasses import dataclass
from typing import Tuple, List, Dict, Optional, Set, DefaultDict
from typing import Tuple, List, Dict, Optional, Set, DefaultDict, Deque
from tinygrad.ops import LoadOps, ScheduleItem, BufferOps, LazyOp, ReduceOps, ConstBuffer, MemBuffer, UNSAFE_PAD_OPS, UnaryOps
from tinygrad.features.graph import log_lazybuffer, realized_lazybuffer
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, prod, dedup, all_int, merge_dicts, getenv
from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, GlobalCounters, prod, dedup, all_int, merge_dicts, getenv, flatten
from tinygrad.shape.symbolic import Variable
from tinygrad.dtype import ImageDType, dtypes
from tinygrad.lazy import LazyBuffer
@@ -92,6 +92,18 @@ def _schedule_group(outs:Tuple[LazyBuffer, ...], realizes:Dict[LazyBuffer, None]
ast.append(LazyOp(BufferOps.STORE, (op, ), MemBuffer(i, out.dtype, output_view)))
return _LBScheduleItem(tuple(ast), outs, tuple(inputs), var_vals)

def _replace_bufis(ast: LazyOp, old_lbs: Tuple[LazyBuffer, ...], new_lbs: Tuple[LazyBuffer, ...]):
new_arg = MemBuffer(new_lbs.index(old_lbs[ast.arg.idx]), ast.arg.dtype, ast.arg.st) if ast.op in [BufferOps.LOAD, BufferOps.STORE] else ast.arg
return LazyOp(ast.op, tuple(_replace_bufis(x, old_lbs, new_lbs) for x in ast.src), new_arg)

def _merge_prescheduled(prescheduled: List[_LBScheduleItem]):
[review comment, Collaborator] I've gone through this route in multioutput,
def _schedule_outputs(outs:List[_LBScheduleItem], reduce_for_op:Dict[LazyBuffer, LazyBuffer]) -> ScheduleItem:
I think you need to rebuild the entire AST.

# todo: need to toposort them somewhere
inputs: Tuple[LazyBuffer, ...] = tuple(dedup(flatten(psi.inputs for psi in prescheduled)))
outputs: Tuple[LazyBuffer, ...] = tuple(dedup(flatten(psi.outputs for psi in prescheduled)))
var_vals: Dict[Variable, int] = merge_dicts([psi.var_vals.copy() for psi in prescheduled])
ast: Tuple[LazyOp] = tuple(_replace_bufis(ast, psi.outputs+psi.inputs, outputs+inputs) for psi in prescheduled for ast in psi.ast)
return _LBScheduleItem(ast, outputs, inputs, var_vals)
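
A minimal sketch (hypothetical buffer names) of the index remap _replace_bufis performs when items are merged: a MemBuffer.idx indexes the item's own outputs+inputs list, so each old index is resolved to its buffer and looked up again in the merged outputs+inputs ordering.

old_lbs = ("out_a", "in_x", "in_y")              # outputs+inputs of one prescheduled item
merged_lbs = ("out_a", "out_b", "in_x", "in_y")  # dedup'd outputs+inputs of the merged item
old_idx = 1                                      # a LOAD that pointed at "in_x"
new_idx = merged_lbs.index(old_lbs[old_idx])     # -> 2 in the merged kernel
print(new_idx)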

# recursively search the entire graph for all LazyBuffers, insert realizes after expands
def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[LazyBuffer, None],
simple_pads:Set[LazyBuffer], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], scheduled=False):
@@ -236,6 +248,73 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
graph[key].append(assign)
in_degree[assign] += 1

def get_bijectives(csi):
bijectives = set()
for ast in csi.ast:
for lop in ast.lazyops:
if lop.op is BufferOps.LOAD:
membuf: MemBuffer = lop.arg
if membuf.st.shape == csi.outputs[0].st.shape or not membuf.st.bijective: continue # check if is earlybuf and bijective
bijectives.add(((csi.outputs + csi.inputs)[membuf.idx], membuf.st, csi.outputs[0].st))
return bijectives
pre_q: Deque[_LBScheduleItem] = deque()
reduce_collector: List[Tuple[_LBScheduleItem, Set[Tuple[LazyBuffer, ShapeTracker, ShapeTracker]]]] = []
for csi in (si for key, si in prescheduled.items() if in_degree[key] == 0):
bijectives = get_bijectives(csi)
if bijectives: reduce_collector.append((csi, bijectives))
else: pre_q.append(csi)
preschedule_groups = []
pre_in_deg = {k: v for k, v in in_degree.items()}
while pre_q or reduce_collector:
if pre_q:
ps = pre_q.popleft()
else:
ps, bij = reduce_collector.pop(0)
to_group = []
this_group = [ps]
best_bij = bij
for rci, (ps_, bij_) in enumerate(reduce_collector):
if best_bij <= bij_ or bij_ <= best_bij:
to_group.append(rci)
best_bij = best_bij | bij_
for rcoff, rci in enumerate(to_group):
ps_, bij_ = reduce_collector.pop(rci - rcoff)
this_group.append(ps_)
pre_q.append(ps_)
if len(this_group) > 1:
preschedule_groups.append(this_group)

to_enqueue: List[_LBScheduleItem] = []
for x in graph[ps.outputs[0]]:
pre_in_deg[x] -= 1
if pre_in_deg[x] == 0: to_enqueue.append(prescheduled[x])
# chase to children with ps as contiguous earlybuf, match by reduced shape
for csi in to_enqueue:
bijectives = get_bijectives(csi)
if bijectives: reduce_collector.append((csi, bijectives))
else: pre_q.append(csi)

# edit the graph with the new groupings
for lsigroup in preschedule_groups:
merged_lsi = _merge_prescheduled(lsigroup)
merged_edges, merged_in_deg = [], 0
for k, edges in graph.items():
new_edges = [lb for lb in edges if lb not in [lsi.outputs[0] for lsi in lsigroup]]
if len(new_edges) != len(edges):
assert k not in [lsi.outputs[0] for lsi in lsigroup], "cycle?"
new_edges += [merged_lsi.outputs[0]]
merged_in_deg += 1
graph[k] = new_edges

for k in [lsi.outputs[0] for lsi in lsigroup]:
merged_edges.extend(graph[k])
del graph[k]
del in_degree[k]
del prescheduled[k]
graph[merged_lsi.outputs[0]] = merged_edges
in_degree[merged_lsi.outputs[0]] = merged_in_deg
prescheduled[merged_lsi.outputs[0]] = merged_lsi

return graph, in_degree, prescheduled

SCHEDULES: List = []
@@ -269,7 +348,7 @@ def _save():
# confirm everything was scheduled correctly
if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
if DEBUG >= 1 and len(schedule) >= 10 or DEBUG >= 3: print(f"scheduled {len(schedule)} kernels")
return schedule, var_vals

def create_schedule(outs:List[LazyBuffer], seen:Optional[Set[LazyBuffer]]=None) -> List[ScheduleItem]:
5 changes: 4 additions & 1 deletion tinygrad/shape/shapetracker.py
@@ -35,7 +35,10 @@ def from_shape(shape:Tuple[sint, ...]): return ShapeTracker((View.create(shape),
def contiguous(self) -> bool: return len(self.views) == 1 and self.views[0].contiguous

@property
def consecutive(self) -> bool: return len(self.views) == 1 and (v:=self.views[0]).mask is None and v.strides == strides_for_shape(v.shape)
def bijective(self) -> bool:
if len(self.views) != 1 or (v := self.views[0]).mask is not None: return False
s_strides, s_shape = zip(*sorted(zip(v.strides, v.shape), reverse=True))
return s_strides == strides_for_shape(s_shape)

@property
def shape(self) -> Tuple[sint, ...]: return self.views[-1].shape
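
A small sketch of the new bijective property (assuming ShapeTracker's usual movement-op methods; expected values follow from the check above): a contiguous or merely permuted single view touches every element exactly once, while an expanded (1, 16) -> (16, 16) view rereads the same 16 values, which is how the scheduler separates "big" loads from the small broadcast ones.

from tinygrad.shape.shapetracker import ShapeTracker

big = ShapeTracker.from_shape((16, 16))
permuted = big.permute((1, 0))
small = ShapeTracker.from_shape((1, 16)).expand((16, 16))
print(big.bijective, permuted.bijective, small.bijective)  # expected: True True False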