[poc/wip] Parallel reduce linearizer changes -- full diff #4415

Closed
wants to merge 22 commits
24 changes: 24 additions & 0 deletions test/test_linearizer.py
@@ -94,6 +94,30 @@ def test_multioutput(self):
assert len(mutable_bufs) == len(stores) == 2
assert [u.arg[0] for u in mutable_bufs] == [0, 1]

def test_multioutput_parallel_r(self):
dtype, st, rst = dtypes.int, ShapeTracker.from_shape((8,2)), ShapeTracker.from_shape((1,2))
a = LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=2, dtype=dtype, st=st))
b = LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=3, dtype=dtype, st=st))
c = LazyOp(BufferOps.LOAD, arg=MemBuffer(idx=4, dtype=dtype, st=st))
const2 = LazyOp(BufferOps.CONST, tuple(), arg=ConstBuffer(2.0, dtype, st))
const3 = LazyOp(BufferOps.CONST, tuple(), arg=ConstBuffer(3.0, dtype, st))
sum1 = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(a,b)),))
sum2 = LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.ADD, src=(a,c)),))
late1 = LazyOp(op=BinaryOps.ADD, src=(sum1, const2))
late2 = LazyOp(op=BinaryOps.ADD, src=(sum2, const3))
out0 = LazyOp(BufferOps.STORE, (late1,), MemBuffer(idx=0, dtype=dtype, st=rst))
out1 = LazyOp(BufferOps.STORE, (late2,), MemBuffer(idx=1, dtype=dtype, st=rst))

lin = Linearizer(out0, out1)
lin.linearize()
src = Device[Device.DEFAULT].compiler.render("test", lin.uops)
print(src)

stores = [u for u in lin.uops if u.uop is UOps.STORE]
mutable_bufs = [u for u in lin.uops if u.uop is UOps.DEFINE_GLOBAL and u.arg[-1]]
assert len(mutable_bufs) == len(stores) == 2
assert [u.arg[0] for u in mutable_bufs] == [0, 1]

def test_load_dedup(self):
# for different leaves in the AST, the same loads may occur.

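Not part of the PR diff: a minimal pure-Python sketch of the computation the new test_multioutput_parallel_r above expects a single fused kernel to perform, with one walk over the shared reduce axis feeding two accumulators, one per output STORE. The flat lists stand in for real device buffers, and the loop structure is only an assumption about how the linearizer lays the kernel out.

def fused_parallel_reduce(a, b, c):
  # a, b, c: flat lists of length 8*2, matching the (8, 2) ShapeTracker in the test
  out0, out1 = [0, 0], [0, 0]
  for j in range(2):                     # kept output axis (rst shape (1, 2))
    acc0 = acc1 = 0
    for i in range(8):                   # reduce axis, traversed once for both sums
      acc0 += a[i*2 + j] + b[i*2 + j]    # sum1 = SUM(a + b)
      acc1 += a[i*2 + j] + c[i*2 + j]    # sum2 = SUM(a + c)
    out0[j] = acc0 + 2                   # late1 = sum1 + const2
    out1[j] = acc1 + 3                   # late2 = sum2 + const3
  return out0, out1

The point of fusing is that the shared input a is traversed by one kernel instead of two; the assertions in the test only check that the fused kernel defines two mutable output buffers and emits two stores.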
113 changes: 112 additions & 1 deletion test/test_schedule.py
@@ -32,7 +32,7 @@ def check_schedule(t:Union[Tensor, List[Tensor]], allowed:int, to_prerealize:Opt
for i, s in enumerate(sched):
print("kernel", i+1)
for op in s.ast: print_tree(op)
assert len(sched) == allowed
assert len(sched) == allowed, f"{len(sched)}, {allowed}"
# test the (non loadops) ops linearize
for s in sched:
if s.ast[0].op in LoadOps: continue
@@ -680,6 +680,117 @@ def test_prefer_half_buffer(self):
# sched = check_schedule([b, c], 4)
# doesn't store either in half because it doesn't chase

def test_batchnorm_train_backward_fusion(self):
with Tensor.train():
x = Tensor.empty((2, 16, 8, 8)).contiguous()
bn = nn.BatchNorm2d(16)
bn.weight.requires_grad = bn.bias.requires_grad = x.requires_grad = True
fw = bn(x).contiguous_backward().relu().contiguous()
fw.sum().backward()
# we want to minimize the number of passes over buffers of the same size as x
# start: 12 kernels (some extraneous from constructing this test case)
# easy case: merge 4 reduces in backward into 1
# double reduce case: merge stat calculations from 2 to 1 (be careful of long reduces!)
# sum(x - \bar{x}): one kernel just calculates this, can be eliminated
# pre-expand fusion: is it fast? -2 kernels possible, 1 fw, 1 bw
# merge reduce into previous conv: -2 kernels on top of the above. requires big linearizer change.
# ideal case: the forward + backward pass is just 3 convs and some small reduces!
check_schedule([x.grad, bn.weight.grad, bn.bias.grad, fw], 9)
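Not part of the PR diff: an illustrative NumPy sketch (shapes and names assumed, not taken from tinygrad) of the per-channel reductions batchnorm's training backward needs. Each of them iterates over a buffer the size of x, which is why the comments above want them merged into fewer passes; the exact kernel split tinygrad produces may differ.

import numpy as np

def bn_backward_stats(x, dy, mean, invstd):
  # x, dy: (N, C, H, W); mean, invstd: (C,) saved from the forward pass
  xhat = (x - mean.reshape(1, -1, 1, 1)) * invstd.reshape(1, -1, 1, 1)
  grad_bias = dy.sum(axis=(0, 2, 3))             # reduce over everything but C
  grad_weight = (dy * xhat).sum(axis=(0, 2, 3))  # same iteration space as above
  # dx is built from these same two per-channel sums, so one fused pass over
  # x and dy can produce everything the weight/bias/input gradients need
  return grad_bias, grad_weight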

def test_parallel_reduce_fusion(self):
x = Tensor.empty(16, 16)
y = Tensor.empty(16, 16)
z = Tensor.empty(16, 16)
a = Tensor.empty(1, 16)
b = Tensor.empty(1, 16)

# we want to fuse reduces where one reduce's inputs of "significant" size
# are a superset of the significant inputs of another reduce (see the sketch after this test)

# === -4 memory passes ===

# should fuse reduces that share a common input, indexing on larger inputs
check_schedule([(x + a).sum(), (x + b).sum()], 1)

# same as above, except one of the sums has 2 big buffers
check_schedule([(x + y + a).sum(), (x + b).sum()], 1)

# same as above, except both sums have 2 big buffers
check_schedule([(x + y + a).sum(), (x + y + b).sum()], 1)

# same as above, except with 3 sums
# do not necessarily require the (x, a) (x, b) (a, b) case
check_schedule([(x + y + a).sum(), (x + b).sum(), (y + b).sum()], 1)

# same as above, except with different late ASTs
check_schedule([(x + y + a).sum() * 2, (x + b).sum() * 3], 1)

# for now, we only want to do this fusion when no significant inputs are added
# this is because adding significant inputs is not free when doing gemms, since
# there is limited L1 cache space. it is probably OK to fuse when there is at most
# one expand axis.
check_schedule([(x + y).sum(), (x + z).sum()], 2)

# pick someone to fuse into if there is ambiguity
check_schedule([(x + y).sum(), (x + z).sum(), (x + a).sum()], 2)

# don't fuse if shapetrackers do not match
check_schedule([(x + a).sum(), (x.permute(1, 0) + b).sum()], 2)

# don't fuse if sums do not match
check_schedule([(x + a).sum(axis=0), (x + b).sum(axis=1)], 2)
check_schedule([(x + a).sum(axis=0), (x + b).sum()], 2)

# maybe don't fuse if it is a 2-axis expand (gemm) and the early asts do not match?
# (because there might be too many accumulators)

# what to do about fusing ImageDType with other dtypes?
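Not part of the PR diff: a hypothetical sketch, with made-up names, of the grouping rule the comments in this test describe. Small tensors like a and b are ignored; two reduces may fuse when one reduce's set of significant (large) input buffers contains the other's, so fusing never adds a large buffer to either kernel.

def can_fuse(big_bufs_a: set, big_bufs_b: set) -> bool:
  # fuse only when neither side gains a significant input it was not already reading
  return big_bufs_a >= big_bufs_b or big_bufs_b >= big_bufs_a

# mirrors the cases above: {x} vs {x} and {x, y} vs {x} fuse,
# while {x, y} vs {x, z} does not (each side would gain a large buffer)
assert can_fuse({"x"}, {"x"}) and can_fuse({"x", "y"}, {"x"})
assert not can_fuse({"x", "y"}, {"x", "z"})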

@unittest.skip("not useful for bn backward")
def test_parallel_r_e_fusion(self):
x = Tensor.empty(16, 16)
y = Tensor.empty(16, 16)
z = Tensor.empty(16, 16)
a = Tensor.empty(1, 16)
b = Tensor.empty(1, 16)

# do parallel fusion also for elementwise
check_schedule([(x + a).sum(), (x * b)], 1)

# do this also for *subtrees* of elementwise, when it will save memory bandwidth
stat = (x + z + a).sum(axis=0, keepdim=True)
check_schedule([stat, ((x + z + b) * stat + y)], 1)

# don't steal if it doesn't reduce mem bw
stat = (x + a).sum(axis=0, keepdim=True)
check_schedule([stat, ((x + b) * stat + y)], 2)

# don't fuse if shapetrackers do not match
check_schedule([(x + a).sum(), (x.permute(1, 0) + b)], 2)

def test_preconv_e_fusion(self):
x = Tensor.empty(16, 16)
y = Tensor.empty(16, 16)
z = Tensor.empty(16, 16)
a = Tensor.empty(1, 16)
conv = nn.Conv2d(16, 16, 3)
conv.weight = Tensor.empty(conv.weight.shape)
conv.bias = Tensor.empty(conv.bias.shape)

# === -4 memory passes (2 dependent on fusing conv(a + b)) ===

# fuse when the input has 1 big buffer
check_schedule([conv(x + a)], 1)

# fuse when the input has 2 big buffers
# very annoying that bn backward needs to fuse 2 big buffers
check_schedule([conv(x + y + a)], 1)

# (for now) don't fuse when the input has 3 big buffers
check_schedule([conv(x + y + z)], 2)

# don't fuse downcast
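Not part of the PR diff: a minimal pure-Python 1-D sketch, with assumed shapes and names, of what fusing the elementwise add into the following conv means. Instead of a separate kernel materializing x + a into its own buffer, the conv kernel computes the add inline while loading its input.

def conv1d_fused_add(x, a, w):
  # x: Cin x L input, a: Cin x 1 broadcast addend, w: Cout x Cin x K weights
  Cout, Cin, K = len(w), len(w[0]), len(w[0][0])
  L = len(x[0])
  out = [[0.0] * (L - K + 1) for _ in range(Cout)]
  for co in range(Cout):
    for p in range(L - K + 1):
      acc = 0.0
      for ci in range(Cin):
        for k in range(K):
          # the elementwise (x + a) happens here instead of in its own kernel,
          # so the intermediate never touches global memory
          acc += (x[ci][p + k] + a[ci][0]) * w[co][ci][k]
      out[co][p] = acc
  return out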

def test_reduce_simple_chase(self):
a = Tensor.empty(4, 4, 4)
r = a.sum(0) + 6
28 changes: 13 additions & 15 deletions tinygrad/codegen/kernel.py
@@ -76,17 +76,15 @@ def __init__(self, *ast:LazyOp, opts:Optional[CompilerOptions]=None):
self.ast = ast
self.lazyops = flatten([op.lazyops for op in self.ast])

# there's only allowed to be one reduceop
reduceops = [x for x in self.lazyops if x.op in ReduceOps]
assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
self.reduceop = reduceops[0] if reduceops else None
self.out_for_reduceop = {x: op for op in self.ast for x in op.lazyops if x.op in ReduceOps} # dedups
self.reduceops = list(self.out_for_reduceop.keys())

self.outbufs, self.vars = [x.arg for x in self.ast], flatten([x.vars() for x in self.ast])
loadops = [BufferOps.LOAD, BufferOps.CONST]
self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = self.outbufs + dedup([x.arg for x in self.lazyops if x.op in loadops])

# get earlybufs, before the one reduce op
self.earlybufs = [x.arg for x in self.reduceop.lazyops if x.op in BufferOps] if self.reduceop else []
self.earlybufs = [x.arg for reduceop in self.reduceops for x in reduceop.lazyops if x.op in BufferOps]
self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0

# create new shapetrackers inside this kernel, we will permute them
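Not part of the PR diff: a toy, standalone illustration of the new bookkeeping above. The dict comprehension maps each deduplicated reduce node to the output AST that contains it; for the two-output AST in test_multioutput_parallel_r that is {sum1: out0, sum2: out1}, and earlybufs now gathers buffers under every reduce instead of a single one. The Node class is a stand-in, not the real LazyOp.

from dataclasses import dataclass

@dataclass(eq=False)  # identity-based hashing/equality for the toy nodes
class Node:
  op: str
  src: tuple = ()
  @property
  def lazyops(self):
    return [self] + [x for s in self.src for x in s.lazyops]

a, b, c = Node("LOAD"), Node("LOAD"), Node("LOAD")
sum1 = Node("SUM", (Node("ADD", (a, b)),))
sum2 = Node("SUM", (Node("ADD", (a, c)),))
out0, out1 = Node("STORE", (sum1,)), Node("STORE", (sum2,))

out_for_reduceop = {x: op for op in (out0, out1) for x in op.lazyops if x.op == "SUM"}
assert list(out_for_reduceop) == [sum1, sum2]
assert out_for_reduceop[sum1] is out0 and out_for_reduceop[sum2] is out1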
@@ -121,8 +119,8 @@ def copy(self):
ret.opts, ret.ast, ret.lazyops = self.opts, self.ast, self.lazyops

# things downstream of the AST
ret.reduceop, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
self.reduceop, self.outbufs, self.vars, [x for x in self.bufs if not isinstance(x, LocalBuffer)], self.earlybufs, self.full_buf_index
ret.reduceops, ret.outbufs, ret.vars, ret.bufs, ret.earlybufs, ret.full_buf_index = \
self.reduceops[:], self.outbufs, self.vars, [x for x in self.bufs if not isinstance(x, LocalBuffer)], self.earlybufs, self.full_buf_index
ret.sts = self.sts[:len(ret.bufs)] # NOTE: must redo the local buffers with TC in beam

# parameters for optimizations
@@ -335,12 +333,12 @@ def alias_buffer(self, i, pattern):
# ******************** high level optimizers ********************

def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op is ReduceOps.SUM and self.opts.device in tensor_cores:
if use_tensor_cores and self.opts.has_local and len(self.reduceops) == 1 and self.reduceops[0].op is ReduceOps.SUM and self.opts.device in tensor_cores:
for tc in tensor_cores[self.opts.device]:
has_cast = tc.dtype_in != tc.dtype_out
if has_cast and not(self.reduceop.src[0].op is UnaryOps.CAST and self.reduceop.src[0].arg[0] == tc.dtype_out): continue
if has_cast and not(self.reduceops[0].src[0].op is UnaryOps.CAST and self.reduceops[0].src[0].arg[0] == tc.dtype_out): continue

mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
mul_op = self.reduceops[0].src[0].src[0] if has_cast else self.reduceops[0].src[0]
if mul_op.op is not BinaryOps.MUL: continue

def buf_index(src: LazyOp) -> Optional[int]:
@@ -445,8 +443,8 @@ def apply_opt(self, opt:Opt, append_opt:bool=True):
if opt.op is not OptOps.PADTO: check(self.full_shape[axis] % amt == 0, "no longer valid shift")
else: amt = -1

if self.reduceop and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
acc_sz, upcast_idx = dt.base.itemsize if isinstance((dt:=self.reduceop.dtype), ImageDType) else dt.itemsize, self.shape_len-self.upcasted
if self.reduceops and (opt.op in {OptOps.GROUP, OptOps.GROUPTOP} or (self.group_for_reduces and opt.op not in {OptOps.NOLOCALS, OptOps.PADTO})):
acc_sz, upcast_idx = sum(dt.base.itemsize if isinstance((dt:=reduceop.dtype), ImageDType) else dt.itemsize for reduceop in self.reduceops), self.shape_len-self.upcasted
upcast_sz = prod([a for a,b in zip(self.full_shape[upcast_idx:], self.sts[0].shape[upcast_idx:]) if a == b])
local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
check(amt*acc_sz*upcast_sz*local_sz <= self.opts.shared_max, "exceeds maximum shared memory size")
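Not part of the PR diff: a hypothetical worked example of the updated shared-memory check, with made-up sizes. The accumulator footprint acc_sz is now summed over all reduceops rather than taken from a single one.

acc_sz = 4 + 4          # two float32 reduceops: itemsizes summed over self.reduceops
amt = 16                # GROUP/GROUPTOP amount being applied
upcast_sz, local_sz = 4, 32
shared_max = 49152      # e.g. a 48 KiB shared-memory budget
assert amt * acc_sz * upcast_sz * local_sz <= shared_max  # 16384 bytes, fits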
@@ -494,7 +492,7 @@ def apply_opt(self, opt:Opt, append_opt:bool=True):
check(not self.vars, "does not work with symbolic shape")
# ok to pad SUM if all parent ops have f(0) = 0
if self.first_reduce <= axis < self.shape_len - self.upcasted:
check(self.reduceop.op is ReduceOps.SUM and all(op.op not in UNSAFE_PAD_OPS for ops in self.reduceop.src for op in ops.lazyops), "cannot pad")
check(all(reduceop.op is ReduceOps.SUM and all(op.op not in UNSAFE_PAD_OPS for ops in reduceop.src for op in ops.lazyops) for reduceop in self.reduceops), "cannot pad")
padded = False
for i,st in enumerate(self.sts):
if self.sts[i].shape[axis] == 1: continue # reduced
@@ -522,8 +520,8 @@ def hand_coded_optimizations(self):
# should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4)
if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \
self.reduceop and self.reduceop.op is ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
(mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is BufferOps.LOAD and mulop.src[1].op is BufferOps.LOAD:
self.reduceops and self.reduceops[0].op is ReduceOps.SUM and len(self.full_shape) >= 2 and self.opts.has_shared and \
(mulop:=self.reduceops[0].src[0]).op is BinaryOps.MUL and mulop.src[0].op is BufferOps.LOAD and mulop.src[1].op is BufferOps.LOAD:
st0, st1 = self.sts[self.bufs.index(mulop.src[0].arg)], self.sts[self.bufs.index(mulop.src[1].arg)]
strides0, strides1 = st0.real_strides(), st1.real_strides()
def has_expanded_axis(shape, strides): return any(s > 1 and st == 0 for s,st in zip(shape,strides))