sort vars in jit when building expected input args (#4990)
* sort vars in jit when building expected input args

Fixes symbolic JIT bugs with two variables.

* sort in clanggraph

* space

* one more
chenyuxyz committed Jun 16, 2024
1 parent 71aad18 commit 72c9b22
Showing 5 changed files with 7 additions and 9 deletions.
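
Why the sort matters: the JIT passes the bound values of symbolic variables to kernels as positional arguments, but the order of a Python dict follows insertion order, which can differ between the capture call and later calls when two or more variables are in play. Sorting by each variable's `expr` (its name) makes the order deterministic on both sides. A minimal, self-contained sketch of the failure mode, using a stand-in `Var` class rather than tinygrad's real `Variable`:

```python
# Stand-in for tinygrad's Variable; only the `expr` name matters here.
from dataclasses import dataclass

@dataclass(frozen=True)
class Var:
  expr: str  # the variable's name, e.g. "i" or "j"

i, j = Var("i"), Var("j")

# The capture call happens to bind the vars in one order...
capture_vals = {i: 3, j: 5}
# ...but a later call may bind them in the other order.
call_vals = {j: 7, i: 2}

# Unsorted, positional args come from dict insertion order, so the kernel
# could receive j's value in i's slot on the second call:
print([v.expr for v in capture_vals])  # ['i', 'j']
print([v.expr for v in call_vals])     # ['j', 'i']

# Sorting by expr gives a stable order on both sides:
key = lambda v: v.expr
assert sorted(capture_vals, key=key) == sorted(call_vals, key=key)
```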
test/test_symbolic_jit.py (2 changes: 0 additions & 2 deletions)

@@ -212,7 +212,6 @@ def f1(a): return a.mean(1).realize()
     expected = a.mean(1).numpy()
     np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
 
-  @unittest.skip("failed for some")
   def test_mean_2d(self):
     def f(a): return a.mean().realize()
     def f0(a): return a.mean(0).realize()
@@ -265,7 +264,6 @@ def f1(a): return a.var(1).realize()
     expected = a.var(1).numpy()
     np.testing.assert_allclose(symbolic, expected, atol=1e-6, rtol=1e-6)
 
-  @unittest.skip("failed for some")
   def test_var_2d(self):
     def f(a): return a.var().realize()
     def f0(a): return a.var(0).realize()
tinygrad/codegen/uops.py (2 changes: 1 addition & 1 deletion)

@@ -276,7 +276,7 @@ def __init__(self, sinks:List[UOp]):
   def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
   def __getitem__(self, index) -> UOp: return self.uops[index]
 
-  def vars(self) -> List[Variable]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR]
+  def vars(self) -> List[Variable]: return sorted([x.arg for x in self.uops if x.uop is UOps.DEFINE_VAR], key=lambda v: v.expr)
   def globals(self) -> List[Tuple[int, bool]]: return [x.arg for x in self.uops if x.uop is UOps.DEFINE_GLOBAL]
 
   @property
tinygrad/engine/jit.py (5 changes: 2 additions & 3 deletions)

@@ -82,7 +82,7 @@ def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], va
         if ji.prg.p.vars: self.jc_idx_with_updatable_var_vals.append(j)
         if (ji.prg.p.global_size and not all_int(ji.prg.p.global_size)) or (ji.prg.p.local_size and not all_int(ji.prg.p.local_size)):
           self.jc_idx_with_updatable_launch_dims.append(j)
-    self.vars = list(var_vals.keys())
+    self.vars = sorted(var_vals.keys(), key=lambda v: v.expr)
     super().__init__(colored(f"<batched {len(self.jit_cache)}>", "cyan"), jit_cache[0].prg.dname.split(":")[0], op_estimate, mem_estimate)
 
 class MultiGraphRunner(GraphRunner): # pylint: disable=abstract-method
@@ -143,8 +143,7 @@ def __call__(self, *args, **kwargs) -> ReturnType:
     assert len(set(input_buffers)) == len(input_buffers), "duplicate inputs to JIT"
     var_vals: Dict[Variable, int] = merge_dicts([varvals for _,varvals,_,_ in st_varvals_dtype_device] + \
       [dict(v.unbind() for v in itertools.chain(args, kwargs.values()) if isinstance(v, Variable))])
-    # TODO: var here is not sorted
-    st_vars_dtype_device = [(x[0], tuple(x[1].keys()), x[2], x[3]) for x in st_varvals_dtype_device]
+    st_vars_dtype_device = [(x[0], tuple(sorted(x[1].keys(), key=lambda v: v.expr)), x[2], x[3]) for x in st_varvals_dtype_device]
     if self.cnt == 0:
       # jit ignore
       with Context(BEAM=0 if getenv("IGNORE_JIT_FIRST_BEAM") else BEAM.value):
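
On the capture side, `GraphRunner` now fixes the positional order in which updated variable values are fed to the batched graph, and the sorted per-input variable tuples keep the JIT's expected-input comparison independent of binding order. A hedged sketch of why the unsorted tuples failed with two variables (reusing the stand-in `Var`, not tinygrad's real `Variable` or its actual signature fields):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Var:
  expr: str  # stand-in for tinygrad's Variable

i, j = Var("i"), Var("j")

# Same logical bindings across two calls, different insertion order:
first  = tuple({i: 3, j: 5}.keys())
second = tuple({j: 5, i: 3}.keys())
assert first != second  # unsorted, the signatures look different

# Sorted by expr, the recorded and replayed signatures compare equal:
sort_key = lambda v: v.expr
assert tuple(sorted({i: 3, j: 5}, key=sort_key)) == tuple(sorted({j: 5, i: 3}, key=sort_key))
```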
tinygrad/ops.py (2 changes: 1 addition & 1 deletion)

@@ -75,7 +75,7 @@ def lazyops(self) -> List[LazyOp]: return dedup([self] + [item for x in self.src
   def vars(self) -> List[Variable]:
     extract_vars = [x.arg.st.vars() for x in self.lazyops if x.op in BufferOps]
     const_vars = [x.arg.val for x in self.lazyops if x.op is BufferOps.CONST and isinstance(x.arg.val, Variable)]
-    return sorted(set.union(*extract_vars, set(const_vars)), key=lambda x: str(x.expr))
+    return sorted(set.union(*extract_vars, set(const_vars)), key=lambda v: v.expr)
 
 # **************** independent FlopCounter ****************
 
tinygrad/runtime/graph/clang.py (5 changes: 3 additions & 2 deletions)

@@ -16,7 +16,7 @@ def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], va
 
     prgs = '\n'.join(dedup([cast(CompiledRunner, ji.prg).p.src for ji in jit_cache]))
     args = [f"{render_dtype(x.dtype)}* arg{i}" for i,x in enumerate(input_rawbuffers)]
-    args += [f"int {v.expr}" for v in var_vals]
+    args += sorted([f"int {v.expr}" for v in var_vals])
     code = ["void batched("+','.join(args)+") {"]
     for ji in jit_cache:
       args = []
@@ -35,4 +35,5 @@ def __init__(self, jit_cache: List[ExecItem], input_rawbuffers: List[Buffer], va
     self.clprg = ClangProgram("batched", compiler.compile(prgs+"\n"+"\n".join(code))) # no point in caching the pointers
 
   def __call__(self, rawbufs: List[Buffer], var_vals: Dict[Variable, int], wait=False):
-    return cpu_time_execution(lambda: self.clprg(*[x._buf for x in rawbufs], *[x for x in var_vals.values()]), enable=wait)
+    return cpu_time_execution(
+      lambda: self.clprg(*[x._buf for x in rawbufs], *[x[1] for x in sorted(var_vals.items(), key=lambda x: x[0].expr)]), enable=wait)
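
ClangGraph compiles the whole batch into one C function whose trailing `int` parameters are the variable names, so the parameter list built at compile time and the values passed at call time must agree on a single ordering. Sorting the rendered `int {name}` strings is equivalent to sorting by `expr` here because the `int ` prefix is the same for every entry. A small sketch of that contract, under the same stand-in `Var` assumption:

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Var:
  expr: str  # stand-in for tinygrad's Variable

i, j = Var("i"), Var("j")
var_vals = {j: 7, i: 2}  # arbitrary insertion order

# Compile time: parameter names in sorted order (the constant "int " prefix
# means sorting the strings sorts by expr).
params = sorted(f"int {v.expr}" for v in var_vals)
# Call time: values sorted by the same key, so i gets 2 and j gets 7.
values = [val for _, val in sorted(var_vals.items(), key=lambda kv: kv[0].expr)]

print(f"void batched({', '.join(params)})")  # void batched(int i, int j)
print(values)                                # [2, 7]
```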
