Skip to content

Commit

Permalink
Convert a bunch more rules [run_process_replay] (#5007)
Browse files Browse the repository at this point in the history
* Convert a bunch more rules [run_process_replay]

* more rules, narrow down CMPLT rule

* smart linter cut two lines

* nope, the linter is dumb

* make dumb linter shut up

* revert two rules

* Revert "revert two rules"

This reverts commit 585688d.

* fix
  • Loading branch information
uuuvn committed Jun 17, 2024
1 parent c52352b commit f1de8cd
Showing 1 changed file with 18 additions and 19 deletions.
37 changes: 18 additions & 19 deletions tinygrad/codegen/uops.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def cmp_tuple(self):
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
def __repr__(self):
return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
def cast(self, dtype): return UOp(UOps.CAST, dtype, (self,))
def cast(self, dtype=None): return UOp(UOps.CAST, dtype, (self,))
def name(self, name:Optional[str]): return UOp(UOps.VAR, vin=(self,), arg=name)
def __neg__(self): return UOp.alu(UnaryOps.NEG, self)
def __add__(self, x): return UOp.alu(BinaryOps.ADD, self, ufix(self.dtype, x))
Expand All @@ -67,6 +67,10 @@ def const(dtype:Optional[DType], b:ConstType|Variable):
@staticmethod
def alu(arg, *vin:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else vin[-1].dtype, vin, arg)
@staticmethod
def load(*vin: UOp, dtype:Optional[DType]=None): return UOp(UOps.LOAD, dtype, tuple(vin))
@staticmethod
def store(*vin: UOp, dtype:Optional[DType]=None): return UOp(UOps.STORE, dtype, tuple(vin))
@staticmethod
def var(name: Optional[str]=None, dtype: Optional[DType]=None): return UOp(UOps.VAR, dtype=dtype, arg=name)
@staticmethod
def cvar(name: Optional[str]=None, dtype: Optional[DType]=None): return UOp(UOps.CONST, dtype=dtype).name(name)
Expand Down Expand Up @@ -99,7 +103,7 @@ class UPat:
@staticmethod
def compile(u: UOp, name:Optional[str]=None) -> UPat:
if u.uop is UOps.VAR: return UPat(name=name or u.arg, dtype=u.dtype) if len(u.vin) == 0 else UPat.compile(u.vin[0], name or u.arg)
return UPat(u.uop, u.arg, (list if u.commutative() else tuple)([UPat.compile(vin) for vin in u.vin]), name, u.dtype)
return UPat(u.uop, u.arg, (list if u.commutative() else tuple)([UPat.compile(vin) for vin in u.vin]) if u.vin != () else None, name, u.dtype)

T = TypeVar("T")
def __unmatch(m1:Union[T, Set[T]], m2:T) -> bool:
Expand Down Expand Up @@ -178,11 +182,10 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
(UPat(UOps.ALU, BinaryOps.MUL, [UPat(UOps.CONST, name="c1"),
UPat(UOps.UNMUL, vin=[UPat(UOps.CONST, name="c2"), UPat(name="v")])]),
lambda c1,c2,v: v if c1.arg == c2.arg else None),
(UPat(UOps.UNMUL, vin=(UPat(UOps.CONST, name="zero", arg=0), UPat())), lambda zero: zero),
(UPat(UOps.CAST, name="root", vin=(UPat(UOps.UNMUL, name="unmul"),)),
lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.vin[0].cast(root.dtype), unmul.vin[1]))),
(UOp(UOps.UNMUL, vin=(UOp.const(None, 0).name('zero'), UOp.var())), lambda zero: zero),
(UOp(UOps.UNMUL).name('unmul').cast().name('root'), lambda root,unmul: UOp(UOps.UNMUL, root.dtype, (unmul.vin[0].cast(root.dtype), unmul.vin[1]))),
# max on special can go away (TODO: special should be variable, same thing applies)
(UPat(UOps.ALU, BinaryOps.MAX, [UPat(UOps.CONST, name="c"), UPat(UOps.SPECIAL, name="s")]), lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
(UOp.max(UOp.cvar('c'), UOp(UOps.SPECIAL).name('s')), lambda c,s: c if (s.arg[2]-1) <= c.arg else None),
# const rules
(UPat(UOps.GEP, name="root", vin=(UPat(UOps.CONST, name="c"),)), lambda root, c: UOp.const(root.dtype, c.arg)),
(UPat(UOps.CAST, name="root", vin=UPat(UOps.CONST, name="c")), lambda root, c: UOp.const(root.dtype, c.arg)),
Expand All @@ -206,7 +209,7 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
(UOp.const(dtypes.bool, True).lt(UOp.var()), lambda: UOp.const(dtypes.bool, False)),
# a conditional with the same results either way is a noop, also fold const conditionals
(UOp.alu(TernaryOps.WHERE, UOp.var(), UOp.var("val"), UOp.var("val")), lambda val: val),
(UOp.alu(TernaryOps.WHERE, UOp.cvar('gate'), UOp.var('true'), UOp.var('false')), lambda gate, true, false: true if gate.arg else false),
(UOp.alu(TernaryOps.WHERE, UOp.cvar('gate'), UOp.var('c0'), UOp.var('c1')), lambda gate, c0, c1: c0 if gate.arg else c1),
# ** constant folding **
(UPat(UOps.ALU, name="root", vin=UPat(UOps.CONST)), lambda root: UOp.const(root.dtype, exec_alu(root.arg, root.dtype, [x.arg for x in root.vin]))),
# ** self folding **
Expand Down Expand Up @@ -239,28 +242,24 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
((UOp.var("x") // UOp.cvar("c0")) // UOp.cvar("c1"), lambda x,c0,c1: x//UOp.const(x.dtype, exec_alu(BinaryOps.MUL, x.dtype, [c0.arg, c1.arg]))),
# c0 + x < c1 -> x < c1 - c0
((UOp.cvar("c0") + UOp.var("x")).lt(UOp.cvar("c1")),
lambda x,c0,c1: UOp.alu(BinaryOps.CMPLT, x, UOp.const(x.dtype, exec_alu(BinaryOps.SUB, x.dtype, [c1.arg, c0.arg])))),
lambda x,c0,c1: UOp.lt(x, UOp.const(x.dtype, exec_alu(BinaryOps.SUB, x.dtype, [c1.arg, c0.arg])))),
# (x+x*c0)-> x*(c0+1)
(UOp.var("x") + UOp.var("x") * UOp.cvar("c0"), lambda x,c0: x*UOp.const(x.dtype, c0.arg+1)),
# TODO: can do the invert of this (flip alt/load) when we fix double ops
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"), UPat(UOps.ALU, TernaryOps.WHERE,
(UPat(name="gate"), UPat(name="alt"), UPat(UOps.LOAD, vin=(UPat(name="buf"), UPat(name="idx"))))))),
lambda buf, idx, gate, alt: UOp(UOps.STORE, None, (buf, idx, alt, gate))),
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
# store float4/float2 directly (remove CAST/GEP)
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"), UPat(UOps.CAST, vin=
tuple(UPat(UOps.GEP, i, vin=(UPat(name="val"),)) for i in range(4))))),
lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
(UPat(UOps.STORE, vin=(UPat(name="buf"), UPat(name="idx"), UPat(UOps.CAST, vin=
tuple(UPat(UOps.GEP, i, vin=(UPat(name="val"),)) for i in range(2))))),
lambda buf,idx,val: UOp(UOps.STORE, None, (buf, idx, val))),
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(4)))),
lambda buf, idx, val: UOp.store(buf, idx, val)), # pylint: disable=unnecessary-lambda
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(2)))),
lambda buf, idx, val: UOp.store(buf, idx, val)), # pylint: disable=unnecessary-lambda
# CAST-PHI-GEP -> PHI-CAST
(UPat(UOps.CAST, name="root", vin=tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
(UPat(UOps.CAST, name="root", vin=tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(2))),
lambda root, val, v0, v1: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1))))),
# NEG/CMPLT -> CMPLT
(UPat(UOps.ALU, BinaryOps.CMPLT, (UPat(UOps.ALU, UnaryOps.NEG, (UPat(name="x"),)), UPat(UOps.CONST, name="c", dtype=dtypes.int))),
lambda c,x: UOp(UOps.ALU, dtypes.bool, (UOp.const(c.dtype, -c.arg), x), BinaryOps.CMPLT)),
(UOp.lt(-UOp.var('x'), UOp.cvar('c', dtypes.int)), lambda c,x: UOp.lt(UOp.const(c.dtype, -c.arg), x)),
# cast NOOP (NOTE: it's str to deal with PtrDType)
(UPat(UOps.CAST, name="root"), lambda root: root.vin[0] if str(root.dtype) == str(root.vin[0].dtype) else None),
])
Expand Down

0 comments on commit f1de8cd

Please sign in to comment.