save_trace [run_process_replay]
Qazalin committed May 24, 2024
1 parent a921f33 commit 1d41a3e
Showing 5 changed files with 104 additions and 4 deletions.
19 changes: 17 additions & 2 deletions .github/workflows/benchmark.yml
@@ -27,6 +27,18 @@ jobs:
        ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
        ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
        ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
    - name: Update process replay reference (master only)
      if: github.ref == 'refs/heads/master'
      run: |
        export TRACE_PATH="$HOME/traces/$GITHUB_SHA" && echo "TRACE_PATH=$TRACE_PATH" >> $GITHUB_ENV
        rm -rf "$HOME/traces" && mkdir -p "$HOME/traces"
        echo "SAVE_TRACE=1" >> $GITHUB_ENV
    - name: Setup process replay
      if: contains(github.event.head_commit.message, '[run_process_replay]')
      run: |
        export TRACE_PATH="$HOME/traces/$GITHUB_SHA" && echo "TRACE_PATH=$TRACE_PATH" >> $GITHUB_ENV
        rm -rf "$TRACE_PATH" && touch $TRACE_PATH
        echo "SAVE_TRACE=1" >> $GITHUB_ENV
    - name: Run Stable Diffusion
      run: JIT=2 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
    - name: Run model inference benchmark
@@ -46,7 +58,7 @@ jobs:
        JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
        JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
    - name: Run LLaMA with BEAM
-      run: JIT=1 BEAM=2 CACHELEVEL=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
+      run: JIT=1 SAVE_TRACE=0 BEAM=2 CACHELEVEL=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
    - name: Run quantized LLaMA
      run: |
        JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8 | tee llama_int8.txt
@@ -60,7 +72,7 @@ jobs:
    - name: Run GPT2 w HALF
      run: JIT=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
    - name: Run GPT2 w HALF/BEAM
-      run: JIT=1 HALF=1 BEAM=2 CACHELEVEL=0 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
+      run: JIT=1 HALF=1 SAVE_TRACE=0 BEAM=2 CACHELEVEL=0 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
    - name: Train MNIST
      run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=97.3 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
    - name: Run 10 CIFAR training steps
@@ -71,6 +83,9 @@ jobs:
    # run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
    - name: Run 10 CIFAR training steps w winograd
      run: JIT=2 WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
    - name: Run process replay tests
      if: contains(github.event.head_commit.message, '[run_process_replay]')
      run: python3 test/external/replay_benchmarks.py
    - uses: actions/upload-artifact@v4
      with:
        name: Speed (Mac)
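How these pieces fit: on master, the first new step wipes $HOME/traces and points TRACE_PATH at the current commit, so the benchmark run that follows becomes the new reference trace; on a commit tagged [run_process_replay], only that commit's trace file is reset, and the replay step at the end compares it against the reference. The BEAM runs set SAVE_TRACE=0, presumably because beam search compiles many candidate kernels that would make the trace nondeterministic. A minimal sketch of resolving a trace file the same way the workflow lays it out under $HOME/traces/$GITHUB_SHA (the helper name is hypothetical; run inside a tinygrad checkout):

  import os, subprocess

  def trace_path_for(rev: str) -> str:
    # the workflow writes each run's trace to $HOME/traces/$GITHUB_SHA
    sha = subprocess.run(["git", "rev-parse", rev], stdout=subprocess.PIPE, check=True, text=True).stdout.strip()
    return os.path.join(os.environ["HOME"], "traces", sha)

  print(trace_path_for("HEAD"))  # e.g. $HOME/traces/1d41a3e...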
67 changes: 67 additions & 0 deletions test/external/replay_benchmarks.py
@@ -0,0 +1,67 @@
from tqdm import tqdm
from dataclasses import dataclass
import pickle, subprocess, os
from typing import Dict, List, Tuple
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.uops import UOp
from tinygrad.engine.graph import print_tree
from tinygrad.helpers import colored
from tinygrad.ops import LazyOp

@dataclass(frozen=True)
class ReplayItem:
  ast: Tuple[LazyOp, ...]
  lin: Linearizer
  src: str

TRACE_DIR = os.path.join(os.environ["HOME"], "traces")
# list the available trace files (debug aid) and make sure origin/master resolves
print(subprocess.run(["ls", "-l", TRACE_DIR], capture_output=True, text=True).stdout)
subprocess.run(["git", "fetch", "origin", "master"], check=True, text=True)

def get_replays(h:str) -> List[ReplayItem]:
  sha = subprocess.run(["git", "rev-parse", h], stdout=subprocess.PIPE, check=True, text=True).stdout.strip()
  replay_items: List[ReplayItem] = []
  with open(os.path.join(TRACE_DIR, sha), "rb") as f:
    # a trace file is a concatenation of pickled dicts, one per traced process
    while True:
      try: trace: Dict[Tuple[LazyOp, ...], Tuple[List[Linearizer], List[str]]] = pickle.load(f)
      except EOFError: break
      for ast, (lins, srcs) in trace.items():
        for lin, src in zip(lins, srcs): replay_items.append(ReplayItem(ast, lin, src))
  return replay_items

feat = get_replays("HEAD")
master = get_replays("origin/master")
assert len(feat) == len(master), f"replay length mismatch: {len(feat)} items on HEAD vs {len(master)} on master"
errored = 0
for f,m in tqdm(zip(feat, master), total=len(master)):
  try: assert f.ast == m.ast
  except AssertionError:
    print("expected:")
    for op in m.ast: print_tree(op)
    print("got:")
    for op in f.ast: print_tree(op)
    errored += 1
  assert f.lin.uops._uops is not None and m.lin.uops._uops is not None
  def _recursive_assert_uops_equal(u0:UOp, u1:UOp):
    # compare non-vin fields
    assert u0.cmp_tuple[:-1] == u1.cmp_tuple[:-1]
    # compare vin
    assert len(u0.vin) == len(u1.vin)
    for v0, v1 in zip(u0.vin, u1.vin): _recursive_assert_uops_equal(v0, v1)
  for u0, u1 in zip(f.lin.uops._uops, m.lin.uops._uops):
    try: _recursive_assert_uops_equal(u0, u1)
    except AssertionError:
      print("expected:")
      print(u1)
      print("got:")
      print(u0)
      errored += 1
  try: assert f.src == m.src
  except AssertionError:
    print("expected:")
    print(m.src)
    print("got:")
    print(f.src)
    errored += 1

if errored == 0: print(colored("process replay tests passed", "green"))
else:
  print(colored(f"{errored} process replay checks failed", "red"))
  exit(1)
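The recursive check above compares every UOp field except its inputs (cmp_tuple[:-1]), then recurses into vin pairwise. A standalone sketch of that strategy on a toy node type (Node and nodes_equal are illustrative, not tinygrad APIs):

  from dataclasses import dataclass
  from typing import Tuple

  @dataclass(frozen=True)
  class Node:
    op: str
    arg: int = 0
    vin: Tuple["Node", ...] = ()

  def nodes_equal(a:Node, b:Node) -> bool:
    # compare everything except the inputs, then recurse into the inputs pairwise
    if (a.op, a.arg) != (b.op, b.arg) or len(a.vin) != len(b.vin): return False
    return all(nodes_equal(x, y) for x, y in zip(a.vin, b.vin))

  assert nodes_equal(Node("ADD", vin=(Node("CONST", 1),)), Node("ADD", vin=(Node("CONST", 1),)))
  assert not nodes_equal(Node("ADD", vin=(Node("CONST", 1),)), Node("ADD", vin=(Node("CONST", 2),)))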
3 changes: 3 additions & 0 deletions tinygrad/codegen/linearizer.py
@@ -457,6 +457,9 @@ def to_program(self) -> Program:
    self.linearize()
    info = get_lazyop_info(self.ast[0])
    src = self.opts.render(to_function_name(self.name), self.uops)
    if getenv("SAVE_TRACE"):
      from tinygrad.engine.graph import save_trace
      save_trace(self.ast, self, src)
    ops, mem = self.uops.flops_mem()
    run_count = prod((self.global_size if self.global_size else []) + (self.local_size if self.local_size else []))
    # NOTE: we use min here to ignore the indexing FLOPS
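A sketch of exercising this hook locally, assuming a tinygrad checkout at this commit (the tensor expression is arbitrary): SAVE_TRACE must be set before the tinygrad modules read it, and the trace is flushed to TRACE_PATH at interpreter exit.

  import os
  os.environ["SAVE_TRACE"] = "1"              # must be set before tinygrad modules read it
  os.environ["TRACE_PATH"] = "/tmp/my_trace"  # default is /tmp/trace
  from tinygrad.tensor import Tensor
  (Tensor.ones(16) + 1).realize()             # compiles one elementwise kernel, tracing it
  # at exit, graph.py pickles the trace dict to /tmp/my_trace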
16 changes: 14 additions & 2 deletions tinygrad/engine/graph.py
@@ -1,6 +1,6 @@
-import os, atexit, functools
+import os, atexit, functools, pickle
from collections import defaultdict
-from typing import List, Any, DefaultDict
+from typing import Dict, List, Any, DefaultDict
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LoadOps, BufferOps, TernaryOps, LazyOp
from tinygrad.device import Device
from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, getenv
@@ -20,6 +20,18 @@ def print_globalcounters():
f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
atexit.register(print_globalcounters)

if getenv("SAVE_TRACE"):
  # follow an AST down to Linearizer(s) and Program(s)
  trace: Dict = {}
  def save_trace(ast, lin=None, prg=None):
    if ast not in trace: trace[ast] = ([], [])
    if lin is not None: trace[ast][0].append(lin)
    if prg is not None: trace[ast][1].append(prg)
  def _save_to_file():
    print(f"saving {len(trace)} trace items to", fp:=getenv("TRACE_PATH", "/tmp/trace"))
    with open(fp, "ab") as f: pickle.dump(trace, f)
  atexit.register(_save_to_file)

def save_graph(G, fn, opt=""):
  print("saving", G, f"to {fn}.svg")
  nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
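Since every traced process appends its own pickled dict (note the "ab" mode above), a trace file can hold several pickles back to back, which is why get_replays in replay_benchmarks.py loops pickle.load until EOFError. A self-contained sketch of that round trip:

  import pickle, tempfile, os

  fp = os.path.join(tempfile.mkdtemp(), "trace")
  for chunk in ({"ast0": ([], [])}, {"ast1": (["lin"], ["src"])}):  # two "processes" append
    with open(fp, "ab") as f: pickle.dump(chunk, f)

  chunks = []
  with open(fp, "rb") as f:
    while True:
      try: chunks.append(pickle.load(f))
      except EOFError: break
  assert len(chunks) == 2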
3 changes: 3 additions & 0 deletions tinygrad/engine/schedule.py
@@ -311,6 +311,9 @@ def _save():
      pickle.dump(SCHEDULES, open(fp, "wb"))
    if len(SCHEDULES) == 0: atexit.register(_save)
    SCHEDULES.extend((ps.ast for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
  if getenv("SAVE_TRACE"):
    from tinygrad.engine.graph import save_trace
    for ps in prescheduled.values(): save_trace(ps.ast)
  # confirm everything was scheduled correctly
  if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
    raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")
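Taken together, the two hooks populate the trace in two phases: scheduling registers each AST with empty lists, and to_program later appends the Linearizer and rendered source. A standalone restatement of that effect (string stand-ins replace real ASTs and Linearizers):

  trace = {}
  def save_trace(ast, lin=None, prg=None):
    if ast not in trace: trace[ast] = ([], [])
    if lin is not None: trace[ast][0].append(lin)
    if prg is not None: trace[ast][1].append(prg)

  save_trace("AST0")                          # schedule.py: AST registered, no kernel yet
  save_trace("AST0", lin="LIN0", prg="SRC0")  # linearizer.py: kernel and source appended
  assert trace == {"AST0": (["LIN0"], ["SRC0"])}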
