Skip to content

Commit

Permalink
allow keyword args in UOp.store [run_process_replay] (#5008)
Browse files Browse the repository at this point in the history
* allow keyword args in UOp.store [run_process_replay]

* same for load

* typing can stay
  • Loading branch information
Qazalin committed Jun 17, 2024
1 parent f1de8cd commit 026c595
Showing 1 changed file with 4 additions and 6 deletions.
10 changes: 4 additions & 6 deletions tinygrad/codegen/uops.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,9 @@ def const(dtype:Optional[DType], b:ConstType|Variable):
@staticmethod
def alu(arg, *vin:UOp): return UOp(UOps.ALU, dtypes.bool if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} else vin[-1].dtype, vin, arg)
@staticmethod
def load(*vin: UOp, dtype:Optional[DType]=None): return UOp(UOps.LOAD, dtype, tuple(vin))
def load(*vin:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.LOAD, dtype, tuple(vin)+tuple(kwargs.values()))
@staticmethod
def store(*vin: UOp, dtype:Optional[DType]=None): return UOp(UOps.STORE, dtype, tuple(vin))
def store(*vin:UOp, dtype:Optional[DType]=None, **kwargs): return UOp(UOps.STORE, dtype, tuple(vin)+tuple(kwargs.values()))
@staticmethod
def var(name: Optional[str]=None, dtype: Optional[DType]=None): return UOp(UOps.VAR, dtype=dtype, arg=name)
@staticmethod
Expand Down Expand Up @@ -249,10 +249,8 @@ def loop_collapse(loop_start, loop_end, compval, idx, mval, multconst):
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp.alu(TernaryOps.WHERE, UOp.var("gate"), UOp.var("alt"), UOp.load(UOp.var("buf"), UOp.var("idx")))),
lambda buf, idx, gate, alt: UOp.store(buf, idx, alt, gate)),
# store float4/float2 directly (remove CAST/GEP)
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(4)))),
lambda buf, idx, val: UOp.store(buf, idx, val)), # pylint: disable=unnecessary-lambda
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(2)))),
lambda buf, idx, val: UOp.store(buf, idx, val)), # pylint: disable=unnecessary-lambda
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(4)))), UOp.store),
(UOp.store(UOp.var("buf"), UOp.var("idx"), UOp(UOps.CAST, vin=tuple(UOp(UOps.GEP, arg=i, vin=(UOp.var("val"),)) for i in range(2)))), UOp.store),
# CAST-PHI-GEP -> PHI-CAST
(UPat(UOps.CAST, name="root", vin=tuple(UPat(UOps.PHI, vin=(UPat(UOps.GEP, i, vin=(UPat(name="val"),)), UPat(name=f"v{i}"))) for i in range(4))),
lambda root, val, v0, v1, v2, v3: UOp(UOps.PHI, root.dtype, (val, UOp(UOps.CAST, val.dtype, (v0, v1, v2, v3))))),
Expand Down

0 comments on commit 026c595

Please sign in to comment.