linearizer: enable GROUP opts after TC
fixes strides on the group_for_reduce buffer so that GROUP can be
used correctly after TC. also reverses the local index order so that
the first local indices are on the left.

fix tests that depend on the local index order and remove a duplicated test
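
As a rough illustration of the new index order (a minimal standalone sketch, not tinygrad's actual get_grouped_dims; the helper names and dimension values are made up): when a kernel has more local dims than the device exposes local indices, the extra dims are now packed into the first (leftmost) index and decoded left to right.

from math import prod

def pack_local_dims(local_dims, maxdim=3):
    # hypothetical helper: sizes of the hardware local indices after packing;
    # dims beyond the limit now collapse into the first index instead of the last
    if maxdim == 0 or len(local_dims) <= maxdim: return list(local_dims)
    return [prod(local_dims[:-(maxdim-1)])] + list(local_dims[-(maxdim-1):])

def unpack_first_index(lidx0, local_dims, maxdim=3):
    # recover the packed dims from the first hardware index, left to right
    out = []
    for s in local_dims[:-(maxdim-1)]:
        out.append(lidx0 % s)
        lidx0 //= s
    return out

print(pack_local_dims((2, 3, 4, 8)))        # -> [6, 4, 8]: the 2 and 3 dims share lidx0
print(unpack_first_index(5, (2, 3, 4, 8)))  # -> [1, 2]: lidx0 = 5 decodes to dims (1, 2)
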
flammit committed May 3, 2024
1 parent c736851 commit 57fe812
Showing 3 changed files with 46 additions and 55 deletions.
41 changes: 13 additions & 28 deletions test/test_linearizer.py
@@ -157,7 +157,7 @@ def test_reduce_upcast(self):
assert stores[0].vin[-1].dtype == accs[0].dtype == dtypes.float.vec(4)

def test_upcast_with_locals(self):
if not (opts:=Device[Device.DEFAULT].compiler.compiler_opts).has_local or not opts.has_shared or not opts.supports_float4:
if not (opts:=Device[Device.DEFAULT].compiler.compiler_opts).has_local or not opts.has_shared:
self.skipTest("device does not support upcasted reduce with locals")

x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
@@ -167,14 +167,19 @@ def test_upcast_with_locals(self):
k.linearize()

accs = [u for u in k.uops if u.uop is UOps.DEFINE_ACC]
assert len(accs) == 8

stores = [u for u in k.uops if u.uop is UOps.STORE]
assert len(stores) == 8

# the first store is to lds and can be upcasted
assert accs[0].dtype == stores[0].vin[-1].dtype == dtypes.float.vec(4)
assert stores[0].vin[0].uop is UOps.DEFINE_LOCAL
# the second store is to gds with no upcasts
assert accs[1].dtype == stores[1].vin[-1].dtype == dtypes.float
assert stores[1].vin[0].uop is UOps.DEFINE_GLOBAL
# the first four stores are to lds with no upcasts
for i in range(4):
assert accs[i].dtype == stores[i].vin[-1].dtype == dtypes.float
assert stores[i].vin[0].uop is UOps.DEFINE_LOCAL
# the second four stores are to gds with no upcasts
for i in range(4,8):
assert accs[i].dtype == stores[i].vin[-1].dtype == dtypes.float
assert stores[i].vin[0].uop is UOps.DEFINE_GLOBAL

def test_zero_fold(self):
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
@@ -991,26 +996,6 @@ def test_grouped_store_locals_and_globals(self):
assert barrier.vin == tuple(local_stores)
assert len([u for u in k.uops if u.uop is UOps.IF and u.vin[-1] == barrier]) == 1

def test_grouped_store_local_only(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.has_shared or \
not Device[Device.DEFAULT].compiler.compiler_opts.supports_float4:
self.skipTest("Only Compiled uses linearizer with locals, shared, and float4")

x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
r = (x@y).relu()
k = Linearizer(*create_schedule([r.lazydata])[-1].ast)
k.hand_coded_optimizations()
k.linearize()

stores = [u for u in k.uops if u.uop is UOps.STORE]

# the float4 value stores directly in lds and we skip upcast
assert stores[0].vin[-1].dtype == dtypes.float.vec(4)
assert stores[0].vin[-1].uop is not UOps.CAST

# the global store doesn't change
assert stores[1].vin[-1].dtype == dtypes.float

def test_skip_unmatching_upcasts(self):
if not Device[Device.DEFAULT].compiler.compiler_opts.has_local or not Device[Device.DEFAULT].compiler.compiler_opts.supports_float4:
self.skipTest("Needs locals and float4")
@@ -1032,7 +1017,7 @@ def test_skip_unmatching_upcasts_with_gep(self):
self.skipTest("Needs locals and float4")
ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(1, 8, 0, 0), offset=0, mask=None, contiguous=False),)))),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(8, 32, 1, 1), strides=(32, 1, 0, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501
opts = [Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.UPCAST, axis=2, amt=2), Opt(op=OptOps.LOCAL, axis=1, amt=8),
Opt(op=OptOps.UPCAST, axis=2, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8),
Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=8),
Opt(op=OptOps.UPCAST, axis=1, amt=0), Opt(op=OptOps.UPCAST, axis=0, amt=2)]

k = Linearizer(ast)
25 changes: 14 additions & 11 deletions tinygrad/codegen/kernel.py
@@ -54,9 +54,9 @@ def fix_axes(self, removed_axis:int): # adjust the TC axes if necesssary when an
elif removed_axis == self.axes[tc_dim]: self.axes_exist[tc_dim] = False

tensor_cores: Dict[str, List[TensorCore]] = {
"METAL": [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[4],[0],[2],[0],[-1, 1, 3],[0]], [[0],[3],[0],[1],[2, 4],[-1]], [[4],[3],[2],[1],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
"HSA": [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[2],[0],[0],[-1],[1]], [[0],[2],[1],[-1],[0]], [[-2],[2],[1],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
"CUDA": [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[-2],[5],[0],[0],[-1,1,2,-3],[3,4]], [[5],[0],[0],[4],[3],[-1,1,2,-2],[0]], [[2],[-2],[5],[1],[-1],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)] if getenv("PTX") else [(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])], # noqa: E501
"METAL": [TensorCore(dims=(8,8,8), threads=[(0,2),(1,4),(0,2),(1,2)], thread_local_sizes=[[2],[2],[2]], thread_local_aliases=[ [[0],[2],[0],[4],[-1, 1, 3],[0]], [[1],[0],[3],[0],[2, 4],[-1]], [[1],[2],[3],[4],[0],[-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.float, dtypes.float), (dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
"HSA": [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], thread_local_sizes=[[16],[16],[4,2]], thread_local_aliases=[ [[0],[0],[2],[-1],[1]], [[1],[2],[0],[-1],[0]], [[1],[2],[-2],[0],[3,-1]] ], dtype_in=di, dtype_out=do) for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]], # noqa: E501
"CUDA": [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(0,2)], thread_local_sizes=[[2,2,2],[2,2],[2,2]], thread_local_aliases=[ [[0],[0],[5],[-2],[0],[-1,1,2,-3],[3,4]], [[3],[4],[0],[0],[5],[-1,1,2,-2],[0]], [[-1],[1],[5],[-2],[2],[0],[3,4]] ], dtype_in=di, dtype_out=do) for (di, do) in ([(dtypes.half, dtypes.float)] if getenv("PTX") else [(dtypes.half, dtypes.float), (dtypes.bfloat16, dtypes.float)])], # noqa: E501
}
tensor_cores["AMD"] = tensor_cores["HSA"]
tensor_cores["NV"] = tensor_cores["CUDA"]
@@ -316,22 +316,24 @@ def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
self.reshape_and_permute(None, order)
if DEBUG >= 3: print("permuted global dim", order, "due to allocation exceeds global limit")

def local_buffer(self, name:str, shape:Tuple[sint, ...], strides:Tuple[sint, ...], dtype:DType) -> LocalBuffer:
self.sts.append(ShapeTracker((View.create(shape, strides),)))
self.bufs.append((lb:=LocalBuffer(name=name, size=self.sts[-1].size, dtype=dtype)))
return lb

def alias_buffer(self, i, pattern):
assert len(pattern) == len(self.sts[i].shape), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"

bst = 1
real_strides = self.sts[i].real_strides()
shp, stride = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern)
shp, stride, bst = [(s if p != 0 else 1) for s,p in zip(self.sts[i].shape, pattern)], [0]*len(pattern), 1
for priority in range(1, max(pattern)+1): # priority. 0 is non local and ignored
for j,p in enumerate(pattern):
if priority == p and real_strides[j] != 0:
stride[j] = bst
bst *= shp[j]

self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size))
if DEBUG >= 4: print("aliasing buffer", self.sts[i])
self.local_alias[i] = cast(LocalBuffer, self.bufs[-1])
self.local_alias[i] = self.local_buffer(f"ldata{i}", tuple(shp), tuple(stride), self.bufs[i].dtype)

# ******************** high level optimizers ********************

@@ -450,12 +452,13 @@ def apply_opt(self, opt:Opt, append_opt:bool=True):
acc_sz, upcast_idx = dt.base.itemsize if isinstance((dt:=self.reduceop.dtype), ImageDType) else dt.itemsize, self.shape_len-self.upcasted
upcast_sz = prod([a for a,b in zip(self.full_shape[upcast_idx:], self.sts[0].shape[upcast_idx:]) if a == b])
local_sz = prod(self.full_shape[self.first_reduce-self.local_dims:self.first_reduce+self.group_for_reduces])
check(amt*acc_sz*upcast_sz*local_sz <= self.opts.shared_max, "exceeds maximum shared memory size")
smem_sz = amt*acc_sz*upcast_sz*local_sz
check(smem_sz <= self.opts.shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {self.opts.shared_max}")

if opt.op is OptOps.LOCAL: # cyan
check(self.opts.has_local, "target does not support local")
check(axis < self.global_dims, "local is for globals")
self.shift_to(axis, amt, insert_before=self.first_reduce-self.local_dims)
self.shift_to(axis, amt, insert_before=self.first_reduce)
self.local_dims += 1
elif opt.op in {OptOps.GROUP, OptOps.GROUPTOP}: # green
check(self.opts.has_local and self.opts.has_shared, "target does not support local or shared mem")
@@ -475,7 +478,7 @@ def apply_opt(self, opt:Opt, append_opt:bool=True):
self.upcast()
elif opt.op is OptOps.UPCAST: # yellow
check(axis < self.first_reduce, "upcast is for non-reduce")
check(not(self.tensor_core and axis >= self.first_reduce-len(self.tensor_core.threads)), "can't upcast TC locals")
check(not(self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
check(amt <= 8, "don't upcast more than 8")
self.shift_to(axis, amt, insert_before=None)
self.upcast()
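
As a point of reference, a rough standalone sketch of the shared-memory estimate that the GROUP/GROUPTOP check above now reports explicitly in its error message (the helper name, the example sizes, and the 64 KiB limit are assumptions, not tinygrad API):

from math import prod

def grouped_smem_bytes(amt, acc_itemsize, upcast_dims, local_and_group_dims):
    # mirrors amt*acc_sz*upcast_sz*local_sz from the diff, with illustrative arguments
    return amt * acc_itemsize * prod(upcast_dims) * prod(local_and_group_dims)

shared_max = 65536  # assumed device limit (64 KiB)
smem_sz = grouped_smem_bytes(amt=16, acc_itemsize=4, upcast_dims=(4,), local_and_group_dims=(16, 2))
assert smem_sz <= shared_max, f"exceeds maximum shared memory size: needs {smem_sz}, max {shared_max}"
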
35 changes: 19 additions & 16 deletions tinygrad/codegen/linearizer.py
@@ -6,22 +6,21 @@
from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType, ConstType
from tinygrad.helpers import colored, DEBUG, prod, getenv, to_function_name
from tinygrad.ops import LazyOp, UnaryOps, BinaryOps, TernaryOps, ReduceOps, ConstBuffer, MemBuffer, BufferOps
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, create_lt_node
from tinygrad.codegen.kernel import LocalBuffer, Kernel
from tinygrad.features.image import to_image_idx

from tinygrad.codegen.uops import UOps, UOp, UOpGraph

def get_grouped_dims(prefix, start_dim, local_dims, maxdim:int=0):
local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate(local_dims[0:maxdim-1] + (prod(local_dims[maxdim-1:]),) if len(local_dims) > maxdim else local_dims)] # noqa: E501
local_idxs = loop_local_idxs = [Variable(f"{prefix}{start_dim+i}", 0, s-1) for i,s in enumerate((prod(local_dims[:-(maxdim-1)]),) + local_dims[-(maxdim-1):] if len(local_dims) > maxdim else local_dims)] # noqa: E501
if maxdim != 0 and len(local_dims) > maxdim:
dd = local_idxs[maxdim-1]
dd = local_idxs[0]
nli = []
for s in local_dims[maxdim-1:][::-1]:
for s in local_dims[:-(maxdim-1)]:
nli.append(dd % s)
dd //= s
local_idxs = local_idxs[0:maxdim-1] + nli[::-1]
local_idxs = nli + local_idxs[-(maxdim-1):]
return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]

def expand_idx(node:Node) -> Union[Variable, NumNode]: return next((v for v in node.vars() if v.expr.startswith("_uidx")), NumNode(0))
@@ -197,7 +196,7 @@ def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
full_var, full_var_sz = NumNode(0), 1
if alias[0] != 0:
for i in alias:
next_var = local_idxs[-i] if i > 0 else thread_idxs[-i-1]
next_var = local_idxs[i-1] if i > 0 else thread_idxs[-i-1]
full_var += next_var * full_var_sz
full_var_sz *= next_var.max+1
replace_idxs.append(full_var)
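
A tiny illustrative sketch of the alias-resolution convention after this change (standalone, not the real calc_tc_idxs; the index variable names are invented): a positive alias entry now selects a local index 1-based from the left, consistent with the reversed local order, while a negative entry still selects a tensor-core thread index.

local_idxs = ["lidx0", "lidx1", "lidx2"]   # hypothetical local index variables
thread_idxs = ["tidx0", "tidx1"]           # hypothetical TC thread index variables

def resolve(i):
    # i > 0: 1-based into local_idxs from the left (previously local_idxs[-i]);
    # i < 0: index into thread_idxs, unchanged by this commit
    return local_idxs[i - 1] if i > 0 else thread_idxs[-i - 1]

print(resolve(1), resolve(3), resolve(-1), resolve(-2))  # lidx0 lidx2 tidx0 tidx1
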
@@ -212,7 +211,7 @@ def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
min_alias_idx = min(self.local_alias.keys())
replace_input_idxs = calc_tc_idxs(tc.thread_local_sizes[i-min_alias_idx], tc.thread_local_aliases[i-min_alias_idx])
for n in range(len(tc.threads)):
buf_idxs[self.first_reduce-len(tc.threads)+n] = replace_input_idxs[n] # replace locals
buf_idxs[self.global_dims+n] = replace_input_idxs[n] # replace locals
for n in range(tc.num_upcasts()):
buf_idxs[self.shape_len-self.upcasted+n] = replace_input_idxs[len(tc.threads)+n] # replace upcasts
if DEBUG >= 3: print(f"{localbuf_idx} alias {i}: sts={self.sts[i]} idxs={buf_idxs}")
@@ -223,7 +222,7 @@ def calc_tc_idxs(local_sizes: List[int], aliases: List[List[int]]):
if (tc:=self.tensor_core):
replace_acc_idxs = calc_tc_idxs(tc.thread_local_sizes[2], tc.thread_local_aliases[2])
for n in range(len(tc.threads)):
local_idxs[self.local_dims-len(tc.threads)+n] = replace_acc_idxs[n] # replace locals
local_idxs[n] = replace_acc_idxs[n] # replace locals
for n in range(len(replace_acc_idxs)-len(tc.threads)):
upcast_idxs[n] = replace_acc_idxs[len(tc.threads)+n] # replace upcasts
if DEBUG >= 3: print(f"store alias: sts={self.sts[0]} idxs={global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs}")
@@ -326,7 +325,7 @@ def linearize(self):

# late alias the tensor core buffers
if (tc:=self.tensor_core) and (tc_opts:=self.tensor_core_opts):
alias_pattern = [0]*(self.global_dims+(self.local_dims-len(tc.threads))) + [2]*(len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1]*tc.num_upcasts() + [3]*(self.upcasted-tc.num_upcasts()) # noqa: E501
alias_pattern = [0]*(self.global_dims) + [2]*(len(tc.threads)) + [0]*(self.local_dims-len(tc.threads)) + [0]*(self.shape_len-self.upcasted-self.first_reduce) + [1,1] + [3]*(self.upcasted-2) # noqa: E501
for tc_buf in tc_opts.bufs:
self.alias_buffer(tc_buf, alias_pattern)

@@ -358,12 +357,16 @@ def linearize(self):
for lb in self.local_alias.values():
self.buf_uops[self.bufs.index(lb)] = self.uops.add(UOps.DEFINE_LOCAL,
PtrDType(dtypes.float32), (), (lb.name, self.sts[self.bufs.index(lb)].size))
# add a local buffer for multistage reduce. # TODO: use local alias
# add a local buffer for multistage reduce.
if self.group_for_reduces:
# TODO: the strides of this can be controlled
self.sts.append(ShapeTracker.from_shape(tuple([1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) # noqa: E501
temp_dtype = self.get_base_dtype(self.reduceop.dtype)
self.bufs.append(LocalBuffer("temp", self.sts[-1].size, temp_dtype))
group_shape = [1] * self.global_dims + list(self.full_shape[self.global_dims:self.global_dims+self.local_dims+self.group_for_reduces]) + [1] * (self.shape_len - self.upcasted - self.group_for_reduces - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)] # noqa: E501
out_strides, group_strides, bst = [0 if x is None else x for x in self.sts[0].real_strides()], [1]*len(group_shape), 1
# group shared memory same strides as out, group_for_reduce axes on highest strides
group_axes = sorted([(i,sz,out_strides[i]) for i,sz in enumerate(group_shape) if sz>1 and out_strides[i]>0], key=lambda x: x[2])
for (i,sz,_) in group_axes + [(i,sz,out_strides[i]) for i,sz in enumerate(group_shape) if sz>1 and out_strides[i]==0]:
group_strides[i] = bst
bst *= sz
self.local_buffer("temp", tuple(group_shape), tuple(group_strides), (temp_dtype:=self.get_base_dtype(self.reduceop.dtype)))
self.buf_uops.append(self.uops.add(UOps.DEFINE_LOCAL, PtrDType(temp_dtype), (), ("temp", self.sts[-1].size)))

# kernel name (before late upcast)
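
To make the new temp-buffer layout concrete, a hedged standalone sketch of the stride assignment above (the helper and the example shape/stride values are made up; only the assignment logic mirrors the diff): axes that are visible in the output keep the output's relative stride ordering, and the group_for_reduce axes, which have stride 0 in the output, are pushed onto the highest strides.

def group_buffer_strides(group_shape, out_strides):
    strides, bst = [1] * len(group_shape), 1
    # output-visible axes first, in order of increasing output stride...
    out_axes = sorted([(i, sz, out_strides[i]) for i, sz in enumerate(group_shape)
                       if sz > 1 and out_strides[i] > 0], key=lambda x: x[2])
    # ...then the group_for_reduce axes (stride 0 in the output) on top
    reduce_axes = [(i, sz, out_strides[i]) for i, sz in enumerate(group_shape)
                   if sz > 1 and out_strides[i] == 0]
    for i, sz, _ in out_axes + reduce_axes:
        strides[i] = bst
        bst *= sz
    return strides

# e.g. two local axes present in the output plus one group-reduce axis
print(group_buffer_strides([16, 8, 4], [8, 1, 0]))  # -> [8, 1, 128]
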
@@ -388,9 +391,9 @@ def linearize(self):
self.global_size = [x.max+1 for x in loop_global_idxs][::-1]
self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr.replace("gidx", "idx"), x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
elif self.opts.has_local:
self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs][::-1]
self.global_size, self.local_size = [x.max+1 for x in loop_global_idxs][::-1], [x.max+1 for x in loop_local_idxs]
self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_global_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_global_idxs)}) # noqa: E501
self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (len(loop_local_idxs)-1-i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)}) # noqa: E501
self.loop_uops.update({x.expr:self.uops.add(UOps.SPECIAL, dtypes.int32, (), (i, x.expr, x.max+1)) for i,x in enumerate(loop_local_idxs)})
else:
self.render_loop(loop_global_idxs+loop_local_idxs)

Expand Down
