CAST/BITCAST arg is the dtype only, no more tuple
wpmed92 committed May 11, 2024
1 parent 0dfe330 commit 99157dd
Showing 7 changed files with 16 additions and 13 deletions.
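
For context, this commit flattens the arg of CAST/BITCAST LazyOps: previously the arg was a tuple whose first element was the target dtype (e.g. arg=(dtypes.float, False) in the old test AST below), while whether the op is a bitcast was already encoded in the op itself. After the change the arg is just the target dtype. A minimal illustrative sketch of the before/after shape, using simplified stand-in classes rather than tinygrad's real LazyOp/UnaryOps/DType:

from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Tuple

class UnaryOps(Enum): CAST = auto(); BITCAST = auto()

@dataclass(frozen=True)
class DType:
  itemsize: int
  name: str

@dataclass(frozen=True)
class LazyOp:
  op: UnaryOps
  src: Tuple[Any, ...] = ()
  arg: Any = None

float32 = DType(4, "float")

# before this commit: the arg was a tuple with the target dtype first (sometimes plus a flag)
old_cast = LazyOp(UnaryOps.CAST, (), (float32, False))
# after this commit: the arg is the target dtype itself; CAST vs BITCAST lives in the op
new_cast = LazyOp(UnaryOps.CAST, (), float32)
new_bitcast = LazyOp(UnaryOps.BITCAST, (), float32)
assert old_cast.arg[0] == new_cast.arg  # same target dtype, flatter arg
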
2 changes: 1 addition & 1 deletion test/test_linearizer.py
@@ -750,7 +750,7 @@ def test_buf_index_not_found_tensor_core(self):
if Device.DEFAULT not in tensor_cores:
self.skipTest("No tensor cores for device")

- ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
+ ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.CMPEQ, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.int, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=dtypes.float), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=3, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1243, 256), strides=(1, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(0,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501
k = Linearizer(ast, opts=Device[Device.DEFAULT].renderer)
with self.assertRaises(KernelOptError):
k.apply_opt(Opt(OptOps.TC, 0, 1))
10 changes: 5 additions & 5 deletions test/test_linearizer_failures.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions tinygrad/codegen/kernel.py
@@ -343,7 +343,7 @@ def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool:
if use_tensor_cores and self.opts.has_local and self.reduceop and self.reduceop.op is ReduceOps.SUM and self.opts.device in tensor_cores:
for tc in tensor_cores[self.opts.device]:
has_cast = tc.dtype_in != tc.dtype_out
- if has_cast and not(self.reduceop.src[0].op in [UnaryOps.CAST, UnaryOps.BITCAST] and self.reduceop.src[0].arg[0] == tc.dtype_out): continue
+ if has_cast and not(self.reduceop.src[0].op in [UnaryOps.CAST, UnaryOps.BITCAST] and self.reduceop.src[0].arg == tc.dtype_out): continue

mul_op = self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
if mul_op.op is not BinaryOps.MUL: continue
@@ -352,7 +352,7 @@ def buf_index(src: LazyOp) -> Optional[int]:
# TODO: apply tc even if the sources are not from LOAD
if src.op is BufferOps.LOAD and src.arg.dtype == tc.dtype_in: return self.bufs.index(cast(MemBuffer, src.arg))
try:
- if opt_level >= 1 and src.op in [UnaryOps.CAST, UnaryOps.BITCAST] and src.arg[0] == tc.dtype_in:
+ if opt_level >= 1 and src.op in [UnaryOps.CAST, UnaryOps.BITCAST] and src.arg == tc.dtype_in:
return self.bufs.index(cast(MemBuffer, src.src[0].arg))
except ValueError: return None
return None
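
With the tuple gone, _apply_tc_opt can match a cast against a tensor core's dtypes with a plain equality on the arg (src.arg == tc.dtype_in, and likewise arg == tc.dtype_out for the reduce's cast). A rough sketch of that predicate, continuing the simplified stand-ins from the snippet near the top; cast_feeds_tensor_core is an illustrative name, not a tinygrad function:

tc_dtype_in = DType(2, "half")
tc_dtype_out = float32

def cast_feeds_tensor_core(src: LazyOp, dtype_in: DType) -> bool:
  # post-commit check: compare the cast's arg to the tensor core dtype directly,
  # no more src.arg[0] tuple indexing
  return src.op in (UnaryOps.CAST, UnaryOps.BITCAST) and src.arg == dtype_in

assert cast_feeds_tensor_core(LazyOp(UnaryOps.CAST, (), tc_dtype_in), tc_dtype_in)
assert not cast_feeds_tensor_core(LazyOp(UnaryOps.CAST, (), tc_dtype_out), tc_dtype_in)
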
2 changes: 1 addition & 1 deletion tinygrad/codegen/linearizer.py
@@ -422,7 +422,7 @@ def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_b
if x in cache: return cache[x]
if x.op in BufferOps: return loaded_buffers[x.arg]
if x.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return [self.uops.add(UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST, \
- self.get_base_dtype(x.arg[0]), (u,), x.arg[0]) for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
+ self.get_base_dtype(x.arg), (u,), x.arg) for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
if x.op in ReduceOps and not do_reduce:
assert offs is None, "not available if we aren't doing reduce"
return acc
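
In ast_parse the emitted cast UOp now takes x.arg directly, both for its dtype (through get_base_dtype) and for its arg. A toy version of that dispatch, reusing the stand-ins above; UOps here is a hypothetical two-member enum, far smaller than the real one:

class UOps(Enum): CAST = auto(); BITCAST = auto()

def cast_uop_and_dtype(x: LazyOp) -> Tuple[UOps, DType]:
  # the uop kind comes from the LazyOp's op; the target dtype is read straight from arg
  kind = UOps.BITCAST if x.op is UnaryOps.BITCAST else UOps.CAST
  return kind, x.arg

assert cast_uop_and_dtype(LazyOp(UnaryOps.BITCAST, (), float32)) == (UOps.BITCAST, float32)
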
5 changes: 3 additions & 2 deletions tinygrad/engine/schedule.py
@@ -107,7 +107,8 @@ def _recurse_lb(buf:LazyBuffer, realizes:Dict[LazyBuffer, None], allbufs:Dict[La
if len(buf.st.views) == 1 and buf.st.views[-1].mask and all_int(buf.base.st.shape) and \
prod(buf.base.st.shape) >= prod([y-x for x,y in buf.st.views[-1].mask]):
simple_pads.add(buf.base)
- elif buf.base.op is UnaryOps.CAST and isinstance(buf.base.srcs[0].dtype, ImageDType) and isinstance(buf.base.arg[0], ImageDType):
+ elif buf.base.op in [UnaryOps.CAST, UnaryOps.BITCAST] and isinstance(buf.base.srcs[0].dtype, ImageDType) and \
+ isinstance(buf.base.arg, ImageDType):
pass # don't realize image to image casts. this is part of a larger problem
else:
realizes[buf.base] = None
@@ -206,7 +207,7 @@ def _graph_schedule(outs:List[LazyBuffer], seen:Set[LazyBuffer]) -> Tuple[Defaul
if not st.contiguous or tr_next.op in ReduceOps: break
tr = tr_next
# don't cast to higher size before store (tr cannot be realized if forced_realize)
- if tr.op is UnaryOps.CAST and tr.arg[0].itemsize > tr.srcs[0].dtype.itemsize:
+ if tr.op in [UnaryOps.CAST, UnaryOps.BITCAST] and tr.arg.itemsize > tr.srcs[0].dtype.itemsize:
tr = tr.srcs[0].base
reduce_for_op[tr] = r
realizes[tr] = None
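
The scheduler reads the target dtype from the arg in two spots: image-to-image casts (and now bitcasts) are left unrealized, and a store is not moved past a cast to a wider dtype (tr.arg.itemsize > tr.srcs[0].dtype.itemsize). A small sketch of the widening check, continuing the stand-ins above; the function name is illustrative:

def cast_widens_before_store(op: UnaryOps, target: DType, src_dtype: DType) -> bool:
  # mirrors the updated guard: don't push a store past a cast to a larger itemsize
  return op in (UnaryOps.CAST, UnaryOps.BITCAST) and target.itemsize > src_dtype.itemsize

half = DType(2, "half")
assert cast_widens_before_store(UnaryOps.CAST, float32, half)      # half -> float widens
assert not cast_widens_before_store(UnaryOps.CAST, half, float32)  # float -> half narrows
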
2 changes: 1 addition & 1 deletion tinygrad/lazy.py
@@ -107,7 +107,7 @@ def cast(self, dtype:DType, bitcast:bool=False):
if not (new_shape[-1]*self.dtype.itemsize) % dtype.itemsize == 0: raise RuntimeError("unsupported size in bitcast")
new_shape = new_shape[:-1] + ((new_shape[-1]*self.dtype.itemsize) // dtype.itemsize,)
cast_op = UnaryOps.BITCAST if bitcast else UnaryOps.CAST
- return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, (dtype,), (self,))
+ return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,))

def is_unrealized_const(self): return self.base.realized is None and self.base.op is LoadOps.CONST and not isinstance(self.base.arg, Variable)
def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views)
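
Besides the arg change, the surrounding cast() context shows the shape rule for bitcasts: the last axis is rescaled by the itemsize ratio and its byte count must divide evenly. A standalone sketch of that arithmetic; bitcast_shape is an illustrative helper, not tinygrad API:

from typing import Tuple

def bitcast_shape(shape: Tuple[int, ...], src_itemsize: int, dst_itemsize: int) -> Tuple[int, ...]:
  # a bitcast reinterprets the last axis: its byte count must divide evenly by the new itemsize
  if (shape[-1] * src_itemsize) % dst_itemsize != 0: raise RuntimeError("unsupported size in bitcast")
  return shape[:-1] + ((shape[-1] * src_itemsize) // dst_itemsize,)

assert bitcast_shape((4, 8), 2, 4) == (4, 4)   # e.g. half -> float32 halves the last axis
assert bitcast_shape((4, 8), 4, 2) == (4, 16)  # float32 -> half doubles it
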
4 changes: 3 additions & 1 deletion tinygrad/ops.py
@@ -76,7 +76,9 @@ def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg}
@functools.cached_property
def dtype(self) -> DType:
if self.op in BufferOps: return self.arg.dtype
- if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]: return self.arg[0]
+ if self.op in [UnaryOps.CAST, UnaryOps.BITCAST]:
+ print(f"arg={self.arg}")
+ return self.arg
return dtypes.bool if self.op in {BinaryOps.CMPLT, BinaryOps.CMPEQ} else self.src[-1].dtype

@functools.cached_property
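
Finally, the cached dtype property in ops.py reduces to returning the arg for CAST/BITCAST. A one-function sketch of that rule with the stand-ins above; lazyop_result_dtype is an illustrative name, not the real cached_property:

def lazyop_result_dtype(op: UnaryOps, arg: Any) -> DType:
  # for CAST/BITCAST the result dtype is now simply the arg
  if op in (UnaryOps.CAST, UnaryOps.BITCAST): return arg
  raise NotImplementedError("other ops derive their dtype from buffers or sources")

assert lazyop_result_dtype(UnaryOps.CAST, float32) is float32
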
