Skip to content

Commit

Permalink
refactor [run_process_replay]
Browse files Browse the repository at this point in the history
delete
  • Loading branch information
Qazalin committed May 24, 2024
1 parent 43efba3 commit bcd3c93
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 43 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
- name: Setup process replay
if: contains(github.event.head_commit.message, '[run_process_replay]')
run: |
export TRACE_PATH="$HOME/traces/a921f3317f644c208c5bd6dbe1ed813eff4ab315" && echo "TRACE_PATH=$TRACE_PATH" >> $GITHUB_ENV
export TRACE_PATH="$HOME/traces/$GITHUB_SHA" && echo "TRACE_PATH=$TRACE_PATH" >> $GITHUB_ENV
rm -rf "$TRACE_PATH" && touch $TRACE_PATH
echo "SAVE_TRACE=1" >> $GITHUB_ENV
- name: Run Stable Diffusion
Expand Down
73 changes: 31 additions & 42 deletions test/external/replay_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,56 +13,45 @@ class ReplayItem:
ast: Tuple[LazyOp, ...]
lin: Linearizer
src: str
def print(self):
for op in self.ast: print_tree(op)
self.lin.uops.print()
print(self.src)

TRACE_DIR = os.path.join(os.environ["HOME"], "traces")
subprocess.run(["git", "fetch", "origin", "master"], check=True, text=True)
def get_replays(h:str):
sha = subprocess.run(["git", "rev-parse", h], stdout=subprocess.PIPE, check=True, text=True).stdout.strip()
def _get_replays(ref:str) -> List[ReplayItem]:
sha = subprocess.check_output(["git", "rev-parse", ref], text=True).strip()
replay_items: List[ReplayItem] = []
with open(os.path.join(TRACE_DIR, sha), "rb") as f:
while True:
try: trace: Dict[Tuple[LazyOp, ...], Tuple[List[Linearizer], List[str]]] = pickle.load(f)
except EOFError: break
for ast, (lins, srcs) in trace.items():
for lin, src in zip(lins, srcs): replay_items.append(ReplayItem(ast, lin, src))
for ast, (lins, srcs) in trace.items(): replay_items.extend(ReplayItem(ast, lin, src) for lin, src in zip(lins, srcs))
return replay_items

feat = get_replays("origin/master")
master = get_replays("HEAD")
errored = 0
for f,m in tqdm(zip(feat, master)):
try:
assert f.ast == m.ast
except AssertionError:
print("excepted:")
for op in m.ast: print_tree(op)
print("got:")
for op in f.ast: print_tree(op)
errored += 1
assert f.lin.uops._uops is not None and m.lin.uops._uops is not None
def _recursive_assert_uops_equal(u0:UOp, u1:UOp):
# compare non-vin
assert u0.cmp_tuple[:-1] == u1.cmp_tuple[:-1]
# compare vin
assert len(u0.vin) == len(u1.vin)
for v0, v1 in zip(u0.vin, u1.vin): _recursive_assert_uops_equal(v0, v1)
for u0, u1 in zip(f.lin.uops._uops, m.lin.uops._uops):
try: _recursive_assert_uops_equal(u0, u1)
except AssertionError:
print("expected:")
print(u1)
print("got:")
print(u0)
errored += 1
try: assert f.src == m.src
except AssertionError:
print("expected:")
print(f.src)
print("got:")
print(m.src)
errored += 1
def _recursive_assert_uops_equal(u0:UOp, u1:UOp):
assert u0.cmp_tuple[:-1] == u1.cmp_tuple[:-1]
# compare vin
assert len(u0.vin) == len(u1.vin)
for v0, v1 in zip(u0.vin, u1.vin): _recursive_assert_uops_equal(v0, v1)

if errored > 0:
print(colored(f"{errored} processes failed", "red"))
exit(1)
print(colored("process replay tests passed", "green"))
if __name__ == "__main__":
feat = _get_replays("HEAD")
master = _get_replays("origin/master")
for i, (f,m) in tqdm(enumerate(zip(feat, master))):
try:
# *** Scheduler
assert f.ast == m.ast
# *** Linearizer
assert f.lin.uops._uops is not None and m.lin.uops._uops is not None
for u0, u1 in zip(f.lin.uops._uops, m.lin.uops._uops): _recursive_assert_uops_equal(u0, u1)
# *** Renderer
assert f.src == m.src
except AssertionError:
print(f"ReplayItem {i} FAILED")
print("ACTUAL:")
m.print()
print("DESIRED:")
f.print()
print(colored("process replay tests passed", "green"))

0 comments on commit bcd3c93

Please sign in to comment.