From 962b44cc9fe736fdbf6ede0b09942e1f60cdca93 Mon Sep 17 00:00:00 2001 From: timmy Date: Mon, 6 May 2024 15:59:06 -0400 Subject: [PATCH 1/4] allow endifs to be inserted before the end of the graph --- test/test_uop_graph.py | 35 +++++++++++++++++++++++++++++++++++ tinygrad/codegen/uops.py | 2 +- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 0996e8706829..f0e6d81518e1 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -46,5 +46,40 @@ def test_const_cast(self): self.assertEqual(out.uop, UOps.CONST) self.assertEqual(out.arg, 0) + def test_early_endif(self): + g = UOpGraph() + g.add(UOps.IF, vin=(g.add(UOps.CONST, dtypes.bool, arg=True),), cachable=False) + g.add(UOps.CONST, dtypes.int, arg=0) + g.add_ends() + self.assertEqual(len([x for x in g.uops if x.uop is UOps.ENDIF]), 1, "UOpGraph.add_ends() should not add any extra ENDIFs") + self.assertEqual(g.uops[-1].uop, UOps.ENDIF, "UOpGraph.add_ends() should add ENDIF to the end of the graph") + + g = UOpGraph() + if0 = g.add(UOps.IF, vin=(g.add(UOps.CONST, dtypes.bool, arg=True),), cachable=False) + before_endif = g.add(UOps.CONST, dtypes.int, arg=0) + endif = g.add(UOps.ENDIF, vin=(if0,), cachable=False) + after_endif = g.add(UOps.CONST, dtypes.int, arg=1) + g.add_ends() + self.assertEqual(len([x for x in g.uops if x.uop is UOps.ENDIF]), 1, "UOpGraph.add_ends() should not add any extra ENDIFs") + self.assertLess(g.uops.index(before_endif), g.uops.index(endif), "Early ENDIF should stay at it's place in the graph") + self.assertLess(g.uops.index(endif), g.uops.index(after_endif), "Early ENDIF should stay at it's place in the graph") + + g = UOpGraph() + if0 = g.add(UOps.IF, vin=(g.add(UOps.CONST, dtypes.bool, arg=True),), cachable=False) + before_endif = g.add(UOps.CONST, dtypes.int, arg=0) + endif = g.add(UOps.ENDIF, vin=(if0,), cachable=False) + after_endif = g.add(UOps.CONST, dtypes.int, arg=1) + if1 = g.add(UOps.IF, vin=(g.add(UOps.CONST, dtypes.bool, arg=False),), cachable=False) + after_if2 = g.add(UOps.CONST, dtypes.int, arg=2) + g.add_ends() + self.assertEqual(len([x for x in g.uops if x.uop is UOps.ENDIF]), 2, "UOpGraph.add_ends() should not add any extra ENDIFs") + self.assertLess(g.uops.index(before_endif), g.uops.index(endif), "Early ENDIF should stay at it's place in the graph") + self.assertLess(g.uops.index(endif), g.uops.index(after_endif), "Early ENDIF should stay at it's place in the graph") + self.assertLess(g.uops.index(after_endif), g.uops.index(if1)) + self.assertLess(g.uops.index(if1), g.uops.index(after_if2)) + self.assertEqual(g.uops[-1].uop, UOps.ENDIF, "UOpGraph.add_ends() should add ENDIF to the end of the graph") + self.assertIn(if1, g.uops[-1].vin, "UOpGraph.add_ends() should add ENDIF for the unclosed IF") + + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index f965a3cf80ea..cb0a0a8e6a20 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -209,7 +209,7 @@ def add_ends(self): # add END of loops after the last thing that (recursively) depends on them insert_before = self.uops.index(sorted(list(self.get_recursive_children(u)), key=self.uops.index)[-1])+1 self.add(UOps.ENDLOOP, None, (u,), cachable=False, insert_before=insert_before) - elif u.uop is UOps.IF: + elif u.uop is UOps.IF and all(u not in x.vin for x in self.uops if x.uop is UOps.ENDIF): # END any if statements at the end of the uops self.add(UOps.ENDIF, None, (u,), cachable=False) From 18641c99228108b7e52799bab0050d75725b1910 Mon Sep 17 00:00:00 2001 From: timmy Date: Wed, 8 May 2024 08:57:54 -0400 Subject: [PATCH 2/4] find add ENDIF before next BARRIER --- tinygrad/codegen/uops.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index cb0a0a8e6a20..0799d3dc70f0 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -209,9 +209,11 @@ def add_ends(self): # add END of loops after the last thing that (recursively) depends on them insert_before = self.uops.index(sorted(list(self.get_recursive_children(u)), key=self.uops.index)[-1])+1 self.add(UOps.ENDLOOP, None, (u,), cachable=False, insert_before=insert_before) - elif u.uop is UOps.IF and all(u not in x.vin for x in self.uops if x.uop is UOps.ENDIF): - # END any if statements at the end of the uops - self.add(UOps.ENDIF, None, (u,), cachable=False) + elif u.uop is UOps.IF: + # add END of if after the barrier of the store of the result + insert_before = ([i for i in range(self.uops.index(u), len(self.uops)-1) if self.uops[i].uop is UOps.BARRIER and \ + self.uops[i+1].uop is UOps.LOAD and all([x.vin[0] in self.uops[i+1].vin for x in self.uops[i].vin])] + [None])[0] + self.add(UOps.ENDIF, None, (u,), cachable=False, insert_before=insert_before) def fix_loop_scope(self, get_recursive_parents:Callable[..., Set[UOp]]): loop_stack: List[List[UOp]] = [[]] From f1bbd716f9b14a59222ad0224bca07f3b7c7a9ee Mon Sep 17 00:00:00 2001 From: timmy Date: Wed, 8 May 2024 09:19:13 -0400 Subject: [PATCH 3/4] removing tests with manual ENDIF + linters --- test/test_uop_graph.py | 27 --------------------------- tinygrad/codegen/uops.py | 6 +++--- 2 files changed, 3 insertions(+), 30 deletions(-) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index f0e6d81518e1..a0a398a0c14f 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -54,32 +54,5 @@ def test_early_endif(self): self.assertEqual(len([x for x in g.uops if x.uop is UOps.ENDIF]), 1, "UOpGraph.add_ends() should not add any extra ENDIFs") self.assertEqual(g.uops[-1].uop, UOps.ENDIF, "UOpGraph.add_ends() should add ENDIF to the end of the graph") - g = UOpGraph() - if0 = g.add(UOps.IF, vin=(g.add(UOps.CONST, dtypes.bool, arg=True),), cachable=False) - before_endif = g.add(UOps.CONST, dtypes.int, arg=0) - endif = g.add(UOps.ENDIF, vin=(if0,), cachable=False) - after_endif = g.add(UOps.CONST, dtypes.int, arg=1) - g.add_ends() - self.assertEqual(len([x for x in g.uops if x.uop is UOps.ENDIF]), 1, "UOpGraph.add_ends() should not add any extra ENDIFs") - self.assertLess(g.uops.index(before_endif), g.uops.index(endif), "Early ENDIF should stay at it's place in the graph") - self.assertLess(g.uops.index(endif), g.uops.index(after_endif), "Early ENDIF should stay at it's place in the graph") - - g = UOpGraph() - if0 = g.add(UOps.IF, vin=(g.add(UOps.CONST, dtypes.bool, arg=True),), cachable=False) - before_endif = g.add(UOps.CONST, dtypes.int, arg=0) - endif = g.add(UOps.ENDIF, vin=(if0,), cachable=False) - after_endif = g.add(UOps.CONST, dtypes.int, arg=1) - if1 = g.add(UOps.IF, vin=(g.add(UOps.CONST, dtypes.bool, arg=False),), cachable=False) - after_if2 = g.add(UOps.CONST, dtypes.int, arg=2) - g.add_ends() - self.assertEqual(len([x for x in g.uops if x.uop is UOps.ENDIF]), 2, "UOpGraph.add_ends() should not add any extra ENDIFs") - self.assertLess(g.uops.index(before_endif), g.uops.index(endif), "Early ENDIF should stay at it's place in the graph") - self.assertLess(g.uops.index(endif), g.uops.index(after_endif), "Early ENDIF should stay at it's place in the graph") - self.assertLess(g.uops.index(after_endif), g.uops.index(if1)) - self.assertLess(g.uops.index(if1), g.uops.index(after_if2)) - self.assertEqual(g.uops[-1].uop, UOps.ENDIF, "UOpGraph.add_ends() should add ENDIF to the end of the graph") - self.assertIn(if1, g.uops[-1].vin, "UOpGraph.add_ends() should add ENDIF for the unclosed IF") - - if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 0799d3dc70f0..44d5aa50d40c 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -211,9 +211,9 @@ def add_ends(self): self.add(UOps.ENDLOOP, None, (u,), cachable=False, insert_before=insert_before) elif u.uop is UOps.IF: # add END of if after the barrier of the store of the result - insert_before = ([i for i in range(self.uops.index(u), len(self.uops)-1) if self.uops[i].uop is UOps.BARRIER and \ - self.uops[i+1].uop is UOps.LOAD and all([x.vin[0] in self.uops[i+1].vin for x in self.uops[i].vin])] + [None])[0] - self.add(UOps.ENDIF, None, (u,), cachable=False, insert_before=insert_before) + barriers = ([i for i in range(self.uops.index(u), len(self.uops)-1) if self.uops[i].uop is UOps.BARRIER \ + and self.uops[i+1].uop is UOps.LOAD and all([x.vin[0] in self.uops[i+1].vin for x in self.uops[i].vin])] + [None])[0] + self.add(UOps.ENDIF, None, (u,), cachable=False, insert_before=barriers) def fix_loop_scope(self, get_recursive_parents:Callable[..., Set[UOp]]): loop_stack: List[List[UOp]] = [[]] From b288a5c3cec4114480cdb835a8d0ad01aac49519 Mon Sep 17 00:00:00 2001 From: timmy Date: Wed, 8 May 2024 10:11:51 -0400 Subject: [PATCH 4/4] specifically the next barrier aftr the store of the local result --- tinygrad/codegen/uops.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 44d5aa50d40c..a7929a5e83a9 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -193,13 +193,13 @@ def type_verify(self): assert vin[0].dtype == dtypes.bool, f"{arg} selector dtype mismatch {vin[0].dtype=} != {dtypes.bool}" assert dtype == vin[1].dtype == vin[2].dtype, f"{arg} choice dtype mismatch {dtype=} != {vin[1].dtype=} != {vin[2].dtype=}" - def get_recursive_children(self, x:UOp) -> Set[UOp]: + def get_recursive_children(self, x:UOp, with_phi=False) -> Set[UOp]: deps = set([x]) ssize = 0 while ssize != len(deps): ssize = len(deps) for u in self.uops: - if len(deps.intersection([x for x in u.vin if x.uop is not UOps.PHI])): + if len(deps.intersection([x for x in u.vin if with_phi or x.uop is not UOps.PHI])): deps.add(u) return deps @@ -211,8 +211,7 @@ def add_ends(self): self.add(UOps.ENDLOOP, None, (u,), cachable=False, insert_before=insert_before) elif u.uop is UOps.IF: # add END of if after the barrier of the store of the result - barriers = ([i for i in range(self.uops.index(u), len(self.uops)-1) if self.uops[i].uop is UOps.BARRIER \ - and self.uops[i+1].uop is UOps.LOAD and all([x.vin[0] in self.uops[i+1].vin for x in self.uops[i].vin])] + [None])[0] + barriers = (sorted([self.uops.index(x) for x in list(self.get_recursive_children(u, with_phi=True)) if x.uop is UOps.BARRIER]) + [None])[0] self.add(UOps.ENDIF, None, (u,), cachable=False, insert_before=barriers) def fix_loop_scope(self, get_recursive_parents:Callable[..., Set[UOp]]):