diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
index d98bf08a90692..7475801af05d4 100644
--- a/.github/workflows/benchmark.yml
+++ b/.github/workflows/benchmark.yml
@@ -27,6 +27,18 @@ jobs:
         ln -s ~/tinygrad/weights/bpe_simple_vocab_16e6.txt.gz weights/bpe_simple_vocab_16e6.txt.gz
         ln -s ~/tinygrad/weights/LLaMA weights/LLaMA
         ln -s ~/tinygrad/extra/datasets/cifar-10-python.tar.gz extra/datasets/cifar-10-python.tar.gz
+    - name: Update process replay reference (master only)
+      if: github.ref == 'refs/heads/master'
+      run: |
+        export TRACE_PATH="$HOME/traces/$GITHUB_SHA" && echo "TRACE_PATH=$TRACE_PATH" >> $GITHUB_ENV
+        rm -rf "$HOME/traces" && mkdir -p "$HOME/traces"
+        echo "SAVE_TRACE=1" >> $GITHUB_ENV
+    - name: Setup process replay
+      if: contains(github.event.head_commit.message, '[run_process_replay]')
+      run: |
+        export TRACE_PATH="$HOME/traces/$GITHUB_SHA" && echo "TRACE_PATH=$TRACE_PATH" >> $GITHUB_ENV
+        rm -rf "$TRACE_PATH" && touch "$TRACE_PATH"
+        echo "SAVE_TRACE=1" >> $GITHUB_ENV
     - name: Run Stable Diffusion
       run: JIT=2 python3 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd.txt
     - name: Run model inference benchmark
@@ -46,7 +58,7 @@ jobs:
         JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
         JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_jitted.txt
     - name: Run LLaMA with BEAM
-      run: JIT=1 BEAM=2 CACHELEVEL=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
+      run: JIT=1 SAVE_TRACE=0 BEAM=2 CACHELEVEL=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_beam.txt
     - name: Run quantized LLaMA
       run: |
         JIT=1 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing --quantize int8 | tee llama_int8.txt
@@ -60,7 +72,7 @@ jobs:
     - name: Run GPT2 w HALF
       run: JIT=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
     - name: Run GPT2 w HALF/BEAM
-      run: JIT=1 HALF=1 BEAM=2 CACHELEVEL=0 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
+      run: JIT=1 HALF=1 SAVE_TRACE=0 BEAM=2 CACHELEVEL=0 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
     - name: Train MNIST
       run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=97.3 python3 examples/beautiful_mnist.py | tee beautiful_mnist.txt
     - name: Run 10 CIFAR training steps
@@ -71,6 +83,9 @@ jobs:
     #   run: STEPS=10 DEFAULT_FLOAT=BFLOAT16 python3 examples/hlb_cifar10.py | tee train_cifar_bf16.txt
     - name: Run 10 CIFAR training steps w winograd
       run: JIT=2 WINO=1 STEPS=10 python3 examples/hlb_cifar10.py | tee train_cifar_wino.txt
+    - name: Run process replay tests
+      if: contains(github.event.head_commit.message, '[run_process_replay]')
+      run: python3 test/external/replay_benchmarks.py
     - uses: actions/upload-artifact@v4
       with:
         name: Speed (Mac)
diff --git a/test/external/replay_benchmarks.py b/test/external/replay_benchmarks.py
new file mode 100644
index 0000000000000..71058899b53f1
--- /dev/null
+++ b/test/external/replay_benchmarks.py
@@ -0,0 +1,74 @@
+from tqdm import tqdm
+from dataclasses import dataclass
+import pickle, subprocess, os
+from typing import Dict, List, Tuple
+from tinygrad.codegen.linearizer import Linearizer
+from tinygrad.codegen.uops import UOp
+from tinygrad.engine.graph import print_tree
+from tinygrad.helpers import colored
+from tinygrad.ops import LazyOp
+
+@dataclass(frozen=True)
+class ReplayItem:
+  # one kernel recorded by SAVE_TRACE: the AST, the Linearizer built from it and the rendered source
+  ast: Tuple[LazyOp, ...]
+  lin: Linearizer
+  src: str
+
+TRACE_DIR = os.path.join(os.environ["HOME"], "traces")
+subprocess.run(["git", "fetch", "origin", "master"], check=True, text=True)
+
+def get_replays(h:str) -> List[ReplayItem]:
+  # load every kernel stored in the trace file named after the commit sha that `h` resolves to
+  sha = subprocess.run(["git", "rev-parse", h], stdout=subprocess.PIPE, check=True, text=True).stdout.strip()
+  replay_items: List[ReplayItem] = []
+  # NOTE: pickle.load can execute arbitrary code, only read trace files written by trusted CI runs
+  with open(os.path.join(TRACE_DIR, sha), "rb") as f:
+    # every traced process appends one pickled dict to the file, keep loading until it is exhausted
+    while True:
+      try: trace: Dict[Tuple[LazyOp, ...], Tuple[List[Linearizer], List[str]]] = pickle.load(f)
+      except EOFError: break
+      for ast, (lins, srcs) in trace.items():
+        for lin, src in zip(lins, srcs): replay_items.append(ReplayItem(ast, lin, src))
+  return replay_items
+
+def _recursive_assert_uops_equal(u0:UOp, u1:UOp):
+  # compare non-vin
+  assert u0.cmp_tuple[:-1] == u1.cmp_tuple[:-1]
+  # compare vin
+  assert len(u0.vin) == len(u1.vin)
+  for v0, v1 in zip(u0.vin, u1.vin): _recursive_assert_uops_equal(v0, v1)
+
+# master is the reference ("expected"), HEAD is the commit under test ("got")
+master = get_replays("origin/master")
+feat = get_replays("HEAD")
+errored = 0
+# zip stops at the shorter list, so a changed kernel count has to be reported explicitly
+if len(feat) != len(master):
+  print(f"expected {len(master)} kernels, got {len(feat)}")
+  errored += 1
+for f,m in tqdm(zip(feat, master), total=min(len(feat), len(master))):
+  if f.ast != m.ast:
+    print("expected:")
+    for op in m.ast: print_tree(op)
+    print("got:")
+    for op in f.ast: print_tree(op)
+    errored += 1
+  assert f.lin.uops._uops is not None and m.lin.uops._uops is not None
+  for u0, u1 in zip(f.lin.uops._uops, m.lin.uops._uops):
+    try: _recursive_assert_uops_equal(u0, u1)
+    except AssertionError:
+      print("expected:")
+      print(u1)
+      print("got:")
+      print(u0)
+      errored += 1
+  if f.src != m.src:
+    print("expected:")
+    print(m.src)
+    print("got:")
+    print(f.src)
+    errored += 1
+
+if errored == 0: print(colored("process replay tests passed", "green"))
+else: print(colored(f"{errored} processes failed", "red"))
diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py
index 411ca1a9644bd..467bd0d9bcc79 100644
--- a/tinygrad/codegen/linearizer.py
+++ b/tinygrad/codegen/linearizer.py
@@ -457,6 +457,9 @@ def to_program(self) -> Program:
     self.linearize()
     info = get_lazyop_info(self.ast[0])
     src = self.opts.render(to_function_name(self.name), self.uops)
+    if getenv("SAVE_TRACE"):
+      from tinygrad.engine.graph import save_trace
+      save_trace(self.ast, self, src)
     ops, mem = self.uops.flops_mem()
     run_count = prod((self.global_size if self.global_size else []) + (self.local_size if self.local_size else []))
     # NOTE: we use min here to ignore the indexing FLOPS
diff --git a/tinygrad/engine/graph.py b/tinygrad/engine/graph.py
index e4a0f5edffa77..4c5d3e786c9e8 100644
--- a/tinygrad/engine/graph.py
+++ b/tinygrad/engine/graph.py
@@ -1,6 +1,6 @@
-import os, atexit, functools
+import os, atexit, functools, pickle
 from collections import defaultdict
-from typing import List, Any, DefaultDict
+from typing import Dict, List, Any, DefaultDict
 from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps, LoadOps, BufferOps, TernaryOps, LazyOp
 from tinygrad.device import Device
 from tinygrad.helpers import GRAPHPATH, DEBUG, GlobalCounters, getenv
@@ -20,6 +20,18 @@ def print_globalcounters():
         f"{' '*10}total: {GlobalCounters.kernel_count:5d} kernels {GlobalCounters.global_ops*1e-9:8.2f} GOPS {GlobalCounters.global_mem*1e-9:8.2f} GB {GlobalCounters.time_sum_s*1e3:8.2f} ms") # noqa: E501
 atexit.register(print_globalcounters)
 
+if getenv("SAVE_TRACE"):
+  # follow an AST down to Linearizer(s) and Program(s)
+  trace: Dict = {}
+  def save_trace(ast, lin=None, prg=None):
+    if ast not in trace: trace[ast] = ([], [])
+    if lin is not None: trace[ast][0].append(lin)
+    if prg is not None: trace[ast][1].append(prg)
+  def _save_to_file():
+    print(f"saving {len(trace)} trace items to", fp:=getenv("TRACE_PATH", "/tmp/trace"))
+    with open(fp, "ab") as f: pickle.dump(trace, f)
+  atexit.register(_save_to_file)
+
 def save_graph(G, fn, opt=""):
   print("saving", G, f"to {fn}.svg")
   nx.drawing.nx_pydot.write_dot(G, f'{fn}.dot')
diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py
index f302ef144a43e..3984ca2d6b8ea 100644
--- a/tinygrad/engine/schedule.py
+++ b/tinygrad/engine/schedule.py
@@ -311,6 +311,9 @@ def _save():
       pickle.dump(SCHEDULES, open(fp, "wb"))
     if len(SCHEDULES) == 0: atexit.register(_save)
     SCHEDULES.extend((ps.ast for ps in prescheduled.values()) if getenv("CAPTURE_AST") else [(graph, prescheduled)])
+  if getenv("SAVE_TRACE"):
+    from tinygrad.engine.graph import save_trace
+    for ps in prescheduled.values(): save_trace(ps.ast)
   # confirm everything was scheduled correctly
   if not all(degree == 0 for degree in in_degree.values()) or len(prescheduled) != len(schedule):
     raise RuntimeError(f"cycle detected in graph, prescheduled {len(prescheduled)} but only scheduled {len(schedule)}")