Skip to content

Commit

Permalink
fix commavq benchmark (#4712)
Browse files Browse the repository at this point in the history
* fix _slice and assert explicit device

* with _slice
  • Loading branch information
Qazalin committed May 24, 2024
1 parent 8425506 commit c170ddc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion extra/onnx_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -650,7 +650,7 @@ def Attention(x:Tensor, weights, bias:Optional[Tensor]=None, mask_index:Optional
if unidirectional: # gpt-style
assert hidden_size == v_hidden_size
xqkv = x.linear(weights, bias)
- xq, xk, xv = [xqkv.slice([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
+ xq, xk, xv = [xqkv._slice([None, None, (i*hidden_size, (i+1)*hidden_size)]) for i in range(3)]
else: # bert-style
wq, wk, wv = weights[:,:hidden_size], weights[:,hidden_size:hidden_size+v_hidden_size], weights[:,hidden_size+v_hidden_size:]
bq, bk, bv = (bias[:hidden_size], bias[hidden_size:hidden_size+v_hidden_size], bias[hidden_size+v_hidden_size]) if bias is not None else None
Expand Down
13 changes: 7 additions & 6 deletions test/external/external_model_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import csv, pathlib, time, numpy as np
from os import getenv
+ from tinygrad.device import CompileError
import torch
torch.set_num_threads(1)
import onnx
Expand Down Expand Up @@ -60,8 +61,8 @@ def benchmark_model(m, devices, validate_outs=False):

# print input names
if DEBUG >= 2: print([inp.name for inp in onnx_model.graph.input if inp.name not in excluded])
- try:
- for device in devices:
+ for device in devices:
+ try:
Device.DEFAULT = device
inputs = {k:Tensor(inp) for k,inp in np_inputs.items()}
tinygrad_model = get_run_onnx(onnx_model)
Expand All @@ -72,10 +73,10 @@ def benchmark_model(m, devices, validate_outs=False):
for _ in range(3): {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}
benchmark(m, f"tinygrad_{device.lower()}_jit", lambda: {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}) # noqa: F821
del inputs, tinygrad_model, tinygrad_jitted_model
- except Exception as e:
- # model crashed
- print(f"{m} crashed on {device} with: {e}")
- return
+ except CompileError as e:
+ # METAL fails with buffer count limit
+ if m == "dm" and device == "METAL": return
+ raise e

# convert model to torch
try:
Expand Down

0 comments on commit c170ddc

Please sign in to comment.