Skip to content

Commit

Permalink
bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
geohot committed Mar 6, 2024
1 parent 8d09fdf commit b4b6278
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 16 deletions.
35 changes: 30 additions & 5 deletions test/test_uops.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.device import CompiledASTRunner, Compiled
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.codegen.uops import exec_alu
from tinygrad.codegen.uops import exec_alu, UOpGraph
from test.test_dtype import is_dtype_supported

def _uops_to_prg(uops):
Expand All @@ -29,7 +29,7 @@ def _test_single_value(vals, op, dts):
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype)
buf2 = [Buffer(Device.DEFAULT, 1, dtype).copyin(np.array([a], dtype=dtype.np).data) for a,dtype in zip(vals, dts)]
prg = _uops_to_prg(uops)
prg = _uops_to_prg(UOpGraph(uops))
prg.exec([buf]+buf2)
ret = np.empty(1, output_dtype.np)
buf.copyout(ret.data)
Expand All @@ -43,7 +43,7 @@ def _test_single_value_const(vals, op, dts):
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
buf = Buffer(Device.DEFAULT, 1, output_dtype)
prg = _uops_to_prg(uops)
prg = _uops_to_prg(UOpGraph(uops))
prg.exec([buf])
ret = np.empty(1, output_dtype.np)
buf.copyout(ret.data)
Expand Down Expand Up @@ -90,8 +90,6 @@ def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a<b)
def test_where(self):
self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (PtrDType(dtypes.bool), PtrDType(dtypes.float), PtrDType(dtypes.float)))

# TODO: fix this on all the backends
@unittest.skipIf(not isinstance(Device[Device.DEFAULT], Compiled) or getenv('ARM64', False), "only test for compiled backends, broken on some")
class TestNonFloatUOps(TestUOps):
def test_neg_int32(self): self._test_uop_fxn(UnaryOps.NEG, lambda a: -a, (PtrDType(dtypes.int32), ))
def test_add_int32(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: int(a)+int(b), (PtrDType(dtypes.int32), PtrDType(dtypes.int32)))
Expand All @@ -109,6 +107,33 @@ def test_mul_bool(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: bool(a) a
def test_where_float16(self):
self._test_top_fxn(TernaryOps.WHERE, lambda a,b,c: b if a!=0 else c, (PtrDType(dtypes.bool), PtrDType(dtypes.float16), PtrDType(dtypes.float16)))

class TestBoolUOps(TestUOps):
def _test_uop_bool_fxn(self, op, fxn):
for f in [_test_single_value, _test_single_value_const]:
for a in [False, True]:
self._equal(f([a], op, (dtypes.bool, )*1), fxn(a))

def _test_bop_bool_fxn(self, op, fxn):
for f in [_test_single_value, _test_single_value_const]:
for a in [False, True]:
for b in [False, True]:
self._equal(f([a,b], op, (dtypes.bool, )*2), fxn(a,b))

def _test_top_bool_fxn(self, op, fxn):
for f in [_test_single_value, _test_single_value_const]:
for a in [False, True]:
for b in [False, True]:
for c in [False, True]:
self._equal(f([a,b,c], op, (dtypes.bool, )*3), fxn(a,b,c))

def test_not_bool(self): self._test_uop_bool_fxn(UnaryOps.NEG, lambda a: not a)
def test_add_bool(self): self._test_bop_bool_fxn(BinaryOps.ADD, lambda a,b: a or b)
def test_mul_bool(self): self._test_bop_bool_fxn(BinaryOps.MUL, lambda a,b: a and b)
def test_xor_bool(self): self._test_bop_bool_fxn(BinaryOps.XOR, lambda a,b: a != b)
def test_cmpeq_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPEQ, lambda a,b: a == b)
def test_cmplt_bool(self): self._test_bop_bool_fxn(BinaryOps.CMPLT, lambda a,b: a < b)
def test_where_bool(self): self._test_top_bool_fxn(TernaryOps.WHERE, lambda a,b,c: b if a else c)

class TestExecALU(TestUOps):
def test_sqrt(self):
self.assertEqual(exec_alu(UnaryOps.SQRT, dtypes.int, (0,)), 0)
Expand Down
7 changes: 4 additions & 3 deletions tinygrad/codegen/uops.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def uop_alu_resolve(u:UOp) -> sint:
def phi_resolve_acc(u:UOp) -> UOp: return u if u.uop is UOps.DEFINE_ACC else phi_resolve_acc(u.vin[0])

class UOpGraph:
def __init__(self):
def __init__(self, start_uops:Optional[List[UOp]]=None):
# list of uops
self.uops: List[UOp] = []
self.uops: List[UOp] = [] if start_uops is None else start_uops

# global uop cache
self.saved_exprs: Dict[Tuple, UOp] = dict()
Expand Down Expand Up @@ -88,7 +88,8 @@ def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(),
if arg is BinaryOps.ADD and vin[1].uop is UOps.ALU and vin[1].arg is UnaryOps.NEG:
return self.add(UOps.ALU, dtype, (vin[0], vin[1].vin[0]), BinaryOps.SUB, cachable, insert_before)
# constant folding
if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST: return self.add(UOps.CONST, dtype, arg=-vin[0].arg, insert_before=insert_before)
if arg is UnaryOps.NEG and vin[0].uop is UOps.CONST:
return self.add(UOps.CONST, dtype, arg=-vin[0].arg if dtype != dtypes.bool else not vin[0].arg, insert_before=insert_before)
if arg is TernaryOps.WHERE and vin[1] == vin[2]: return vin[1] # a conditional with the same results either way is a noop
if arg is TernaryOps.WHERE and vin[0].uop is UOps.CONST: return vin[1] if vin[0].arg else vin[2]
if arg is BinaryOps.MUL and vin[0].uop is UOps.CONST and vin[1].uop is UOps.CONST and dtype is not None and dtypes.is_float(dtype):
Expand Down
2 changes: 1 addition & 1 deletion tinygrad/device.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ def to_program(self, k:Linearizer) -> CompiledASTRunner:
ops, mem = k.uops.flops_mem()
run_count = prod((k.global_size if k.global_size else []) + (k.local_size if k.local_size else []))
# NOTE: we use min here to ignore the indexing FLOPS
ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops.uops), self, k.global_size, k.local_size,
ret = CompiledASTRunner(k.name, self.compiler.render(to_function_name(k.name), k.uops), self, k.global_size, k.local_size,
k.uops.vars(), min(info.flops, ops * run_count), min(info.mem_estimate, mem * run_count))
return ret

Expand Down
34 changes: 27 additions & 7 deletions tinygrad/renderer/assembly.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Op
from tinygrad.dtype import dtypes, DType, PtrDType, INVERSE_DTYPES_DICT
from tinygrad.codegen.uops import UOpGraph

def float_to_hex(x): return "%02X%02X%02X%02X" % tuple(struct.pack("f",x)[::-1])
def double_to_hex(x): return "%02X%02X%02X%02X%02X%02X%02X%02X" % tuple(struct.pack("d",x)[::-1])
Expand Down Expand Up @@ -35,11 +36,30 @@ def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pre

def render_kernel(self, kernel, function_name, bufs, regs) -> str: raise NotImplementedError()

def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:List[UOp]) -> str:
def uops_to_asm(lang:AssemblyLanguage, function_name:str, uops:UOpGraph) -> str:
local_size: List[int] = []
kernel:List[str] = []
bufs = []

# here we do a pretransform on UOps to fix some shortcomings of PTX
# all uops must be a register
# TODO: uops class should make these rewrites easier
replace = {}
for u in uops:
for o,n in replace.items():
if o in u.vin and u is not new:
u.vin = tuple(n if x == o else x for x in u.vin)
if u.uop is UOps.ALU and u.arg in {BinaryOps.CMPEQ, BinaryOps.CMPLT} and u.vin[0].dtype is dtypes.bool:
if u.arg == BinaryOps.CMPEQ:
u.arg = BinaryOps.XOR
new = uops.add(UOps.ALU, dtypes.bool, (u,), arg=UnaryOps.NEG, insert_before=uops.uops.index(u)+1)
replace[u] = new
if u.arg == BinaryOps.CMPLT:
new = uops.add(UOps.ALU, dtypes.bool, (u.vin[0],), arg=UnaryOps.NEG, insert_before=uops.uops.index(u))
u.vin = (new, u.vin[1])
u.arg = BinaryOps.MUL
#uops.print()

def kk(*s: str): kernel.append("\n".join(s))

c: DefaultDict[str, int] = defaultdict(int)
Expand Down Expand Up @@ -97,10 +117,10 @@ def cast(a:str, dtype:DType, atype:DType, bitcast=False, u=None, pred=False):
elif uop == UOps.ALU:
assert vin[0].dtype is not None
if args == BinaryOps.CMPLT or args == BinaryOps.CMPEQ:
regs = [cast(r[x], dtypes.int16, dtypes.bool) if x.dtype == dtypes.bool else r[x] for x in vin]
dt = dtypes.int16 if vin[0].dtype == dtypes.bool else vin[0].dtype
kk(lang.asm_for_op[args](pred:=ssa(u,'lt','pred'), *regs, dt, lang.types[dt]))
else: kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
# pass in the other dtype here
kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], vin[0].dtype, lang.types[vin[0].dtype]))
else:
kk(lang.asm_for_op[args](ssa(u, "alu"), *[r[x] for x in vin], dtype, lang.types[dtype]))
elif uop == UOps.DEFINE_ACC: kk(f"mov.b{lang.types[dtype][1:]} {ssa(u, 'acc')}, {const(args, dtype)};")
elif uop == UOps.SPECIAL:
if args[1][0] == "i": kk(f"mov.u32 %{args[1]}, {lang.gid[args[0]]};", f"mov.u32 {(gdim:=ssa(None,'tmp','u32'))}, {lang.gdim[args[0]]};",
Expand Down Expand Up @@ -158,14 +178,14 @@ class PTXLanguage(AssemblyLanguage):
gdim = [f'%nctaid.{chr(120+i)}' for i in range(3)]
lid = [f'%tid.{chr(120+i)}' for i in range(3)]
asm_for_op = {
UnaryOps.NEG: lambda d,a,dt,name: f"neg.{name} {d}, {a};",
UnaryOps.NEG: lambda d,a,dt,name: f"not.pred {d}, {a};" if name == "pred" else f"neg.{name} {d}, {a};",
UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};",
UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};",
UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};",
BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};",
BinaryOps.SUB: lambda d,a,b,dt,name: f"sub.{name} {d}, {a}, {b};",
BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};",
BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.b{name[1:]} {d}, {a}, {b};",
BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};",
BinaryOps.DIV: lambda d,a,b,dt,name: f"div{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a}, {b};",
BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};",
BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};",
Expand Down

0 comments on commit b4b6278

Please sign in to comment.