
Commit

Merge branch 'master' into hcq_signal_scheduler
nimlgen committed Jun 18, 2024
2 parents 9d583b7 + e9c6a36 commit 8bfa4e0
Showing 12 changed files with 69 additions and 20 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/benchmark.yml
@@ -169,9 +169,9 @@ jobs:
- name: Run LLaMA-3 8B BEAM
run: NV=1 JITBEAM=2 CACHELEVEL=0 python3 examples/llama3.py --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_beam.txt
- name: Run LLaMA-3 8B on 4 GPUs
run: NV=1 CACHELEVEL=0 python3 examples/llama3.py --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_four_gpu.txt
run: NV=1 python3 examples/llama3.py --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_four_gpu.txt
- name: Run LLaMA-3 8B on 6 GPUs
run: NV=1 CACHELEVEL=0 python3 examples/llama3.py --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_six_gpu.txt
run: NV=1 python3 examples/llama3.py --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_six_gpu.txt
# - name: Run LLaMA-2 70B
# run: CUDA=1 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
# - name: Run Mixtral 8x7B
@@ -322,9 +322,9 @@ jobs:
- name: Run LLaMA-3 8B BEAM
run: AMD=1 JITBEAM=2 CACHELEVEL=0 python3 examples/llama3.py --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_beam.txt
- name: Run LLaMA-3 8B on 4 GPUs
run: AMD=1 CACHELEVEL=0 python3 examples/llama3.py --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_four_gpu.txt
run: AMD=1 python3 examples/llama3.py --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_four_gpu.txt
- name: Run LLaMA-3 8B on 6 GPUs
run: AMD=1 CACHELEVEL=0 python3 examples/llama3.py --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_six_gpu.txt
run: AMD=1 python3 examples/llama3.py --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark | tee llama3_six_gpu.txt
- name: Run LLaMA-2 70B
run: AMD=1 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
- name: Run Mixtral 8x7B
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -77,7 +77,7 @@ jobs:
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
PYTHONPATH=. DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
- name: Test dtype with Python emulator
run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 test/test_dtype.py
run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 -m pytest test/test_dtype.py test/test_dtype_alu.py
- name: Test ops with Python emulator
run: DEBUG=2 PYTHON=1 python3 -m pytest test/test_ops.py -k "not (test_split or test_simple_cumsum or test_cumsum or test_einsum or test_dot or test_dot_1d or test_big_gemm or test_broadcastdot or test_multidot or test_var_axis or test_std_axis or test_broadcast_full or test_broadcast_partial or test_simple_conv3d or test_dilated_conv_transpose2d or test_simple_conv_transpose3d or test_large_input_conv2d or test_maxpool2d or test_maxpool2d_simple or test_maxpool2d_bigger_stride or test_avgpool2d or test_cat or test_scaled_product_attention or test_scaled_product_attention_causal or test_slice_fancy_indexing_dim_inject_none or test_slice_fancy_indexing_list_indices or test_slice_fancy_indexing_no_dim_collapse or test_slice_fancy_indexing_tuple_indices or test_slice_fancy_indexing_list_with_tensors or test_slice_fancy_indexing_dim_collapse_int)" --durations=20
- name: Test uops with Python emulator
18 changes: 16 additions & 2 deletions test/test_dtype.py
@@ -1,4 +1,4 @@
import unittest, operator, subprocess
import unittest, operator, subprocess, math
import numpy as np
import torch
from typing import Any, List
@@ -112,7 +112,11 @@ def test_resulting_and_init_dtypes_match(self):
dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"]))
data = [1., 2., 0., 0.5, -1.5, 5.25]
for dt in dtypes:
arr = np.asarray(data, dtype=dt)
try:
arr = np.asarray(data, dtype=dt)
except OverflowError:
# TODO: this happens with numpy 2.0, update with proper behavior
continue
tin = Tensor(arr).numpy()
tor = torch.as_tensor(arr).detach().numpy()
assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
@@ -544,6 +548,16 @@ def test_sum(self):
assert (Tensor([0, 1], dtype=dtypes.float32)).sum().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.float64)).sum().dtype == dtypes.float64

@unittest.skipUnless(is_dtype_supported(dtypes.float16), "need float16")
def test_sum_acc_dtype(self):
t = Tensor([40000, 40000], dtype=dtypes.float16)
# default float16 sum returns in float16, overflowed in this case
assert t.sum().dtype == dtypes.float16
assert math.isinf(t.sum().numpy().item())
# specifying acc_dtype keeps the result from being downcasted
assert t.sum(acc_dtype=dtypes.float32).dtype == dtypes.float32
np.testing.assert_allclose(t.sum(acc_dtype=dtypes.float32).numpy(), 80000)

def test_mean(self):
assert (Tensor([0, 1], dtype=dtypes.bool)).mean().dtype == dtypes.float32
assert (Tensor([0, 1], dtype=dtypes.int8)).mean().dtype == dtypes.float32
1 change: 1 addition & 0 deletions test/test_dtype_alu.py
@@ -150,6 +150,7 @@ def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a,
@given(strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))
@unittest.skipIf(Device.DEFAULT == "PYTHON", "TODO: fix cast inf to int32 in PYTHON")
def test_float_midcast_int32(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.float32, dtypes.int32)

@unittest.skip("broken. TODO: fix it")
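
A plausible source of the failure the new skip works around (the exact emulator code path is an assumption here) is that plain Python refuses to convert an infinite float to an integer, so a float32 to int32 midcast that produces inf cannot be emulated with a bare int() call:

    # illustrative only: CPython's behavior when asked to cast inf to an integer
    try:
        int(float("inf"))
    except OverflowError as e:
        print(e)  # cannot convert float infinity to integer
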
2 changes: 1 addition & 1 deletion test/test_nn.py
@@ -352,7 +352,7 @@ def test_instancenorm_3d(self):
z.sum().backward()
torch_z.sum().backward(retain_graph=True)
np.testing.assert_allclose(x.grad.numpy(), torch_x.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=1e-3, rtol=1e-3)
np.testing.assert_allclose(layer.weight.grad.numpy(), torch_layer.weight.grad.detach().numpy(), atol=2e-3, rtol=1e-3)
np.testing.assert_allclose(layer.bias.grad.numpy(), torch_layer.bias.grad.detach().numpy(), atol=1e-3, rtol=1e-3)

def test_embedding(self):
3 changes: 3 additions & 0 deletions test/test_ops.py
@@ -474,6 +474,9 @@ def test_rshift(self):
def test_sin(self):
helper_test_op([(45,65)], lambda x: x.sin())
helper_test_op([()], lambda x: x.sin())
# works on real CUDA but not CUDACPU
if not (getenv("CUDACPU") or (getenv("MOCKGPU") and Device.DEFAULT == "NV")):
helper_test_op(None, lambda x: x.sin(), vals=[[math.nan, math.inf, -math.inf]])
def test_cos(self):
helper_test_op([(45,65)], lambda x: x.cos())
helper_test_op([()], lambda x: x.cos())
20 changes: 19 additions & 1 deletion test/test_tensor.py
@@ -1,10 +1,11 @@
import numpy as np
import torch
import unittest, copy, mmap, random
import unittest, copy, mmap, random, math
from tinygrad import Tensor, Device, dtypes
from tinygrad.helpers import getenv, temp, CI
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
from hypothesis import given, settings, strategies as strat
from test.helpers import is_dtype_supported

settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
settings.load_profile("my_profile")
@@ -302,6 +303,23 @@ def _generate_data(depth):
data = _generate_data(depth)
np.testing.assert_allclose(Tensor(data).numpy(), np.array(data))

def test_tensor_list_special_values(self):
if is_dtype_supported(dtypes.float16):
data = [math.nan, -math.inf, 65504, 65519, 65519.999, 65520, 65520.1]
data = data + [-x for x in data]
np.testing.assert_allclose(Tensor(data, dtype=dtypes.float16).numpy(), np.array(data, dtype=np.float16))

# TODO: numpy changed this behavior in 2.0
# # uint32
# data = [1 << 33, 1 << 32, 1 << 32 - 1, 1]
# data = data + [-x for x in data]
# np.testing.assert_allclose(Tensor(data, dtype=dtypes.uint32).numpy(), np.array(data, dtype=np.uint32))

# # int32
# data = [1 << 33, 1 << 32, 1 << 32 - 1, 1]
# data = data + [-x for x in data]
# np.testing.assert_allclose(Tensor(data, dtype=dtypes.int32).numpy(), np.array(data, dtype=np.int32))

def test_tensor_bytes(self):
data = b"abc123"
t = Tensor(data)
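
A minimal sketch of what the new float16 case exercises, assuming a device with float16 support: out-of-range list values should land on the same results numpy produces, with 65520 rounding up to inf and nan/-inf passing through unchanged.

    import math
    import numpy as np
    from tinygrad import Tensor, dtypes

    data = [math.nan, -math.inf, 65504, 65520]          # 65520 is the first value that overflows float16
    print(Tensor(data, dtype=dtypes.float16).numpy())   # expected to match numpy below
    print(np.array(data, dtype=np.float16))             # [  nan  -inf 65504.   inf]
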
2 changes: 1 addition & 1 deletion test/unit/test_disk_tensor.py
@@ -59,7 +59,7 @@ def test_bitcasts_on_disk(self):
_test_bitcasted(t, dtypes.uint32, 0)
# pi in float16 stored via int16
t.bitcast(dtypes.uint16).assign(Tensor.full((128, 64), 0x4248, dtype=dtypes.uint16)).realize()
_test_bitcasted(t, dtypes.float16, 3.141)
_test_bitcasted(t, dtypes.float16, 3.140625)
_test_bitcasted(t, dtypes.float32, 50.064727)
_test_bitcasted(t, dtypes.uint16, 0x4248)
_test_bitcasted(t, dtypes.uint32, 0x42484248)
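
The tightened constant follows directly from the bit pattern: 0x4248 decodes to exactly 3.140625 in float16, so the old 3.141 literal could never compare equal. A quick standalone check:

    import struct
    # 0x4248 -> sign 0, biased exponent 16, mantissa 584: 2 * (1 + 584/1024) = 3.140625
    print(struct.unpack("<e", (0x4248).to_bytes(2, "little"))[0])  # 3.140625
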
2 changes: 1 addition & 1 deletion tinygrad/engine/graph.py
@@ -61,7 +61,7 @@ def log_lazybuffer(lb:'LazyBuffer', scheduled=False):
for idx,x in enumerate(lb.srcs):
if nm(x) not in G.nodes: log_lazybuffer(x)
if x.base.realized is None and x.base.op is LoadOps.CONST:
label_append.append(f"\nCONST{idx} {x.base.arg}")
label_append.append(f"\nCONST{idx} {x.base.arg:g}")
else:
G.add_edge(nm(x), nm(lb), color='#a0a0a0')
label = '"' + \
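
The :g format spec only affects how CONST labels are rendered in the graph: trailing zeros are dropped and very large or small constants switch to exponent form, for example:

    for v in (2.0, 0.5, 1e-07, 100000000.0):
        print(f"{v:g}")   # 2, 0.5, 1e-07, 1e+08
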
18 changes: 13 additions & 5 deletions tinygrad/ops.py
@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Union, Tuple, Any, List, Dict, Callable
import functools, hashlib, math, operator, ctypes
import functools, hashlib, math, operator, ctypes, struct
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.helpers import prod, dedup
@@ -118,18 +118,26 @@ def wfxn(*args):
python_alu = {
UnaryOps.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan,
UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.SIN: math.sin,
UnaryOps.RECIP: lambda x: 1/x if x != 0 else float('inf'),
UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan,
UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub,
BinaryOps.XOR: operator.xor, BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt,
BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x, y: int(x/y) if y != 0 else x*math.inf,
TernaryOps.WHERE: lambda x,y,z: y if x else z}

def truncate_fp16(x):
try:
x = float(x)
struct.pack("@e", x)
return x
except OverflowError: return math.copysign(math.inf, x)

truncate: Dict[DType, Callable] = {dtypes.bool: bool,
# TODO: float16 and bfloat16?
dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
# TODO: bfloat16
dtypes.float16: truncate_fp16, dtypes.float32: lambda x: ctypes.c_float(x).value, dtypes.float64: lambda x: ctypes.c_double(x).value,
dtypes.uint8: lambda x: ctypes.c_uint8(x).value, dtypes.uint16: lambda x: ctypes.c_uint16(x).value,
dtypes.uint32: lambda x: ctypes.c_uint32(x).value, dtypes.uint64: lambda x: ctypes.c_uint64(x).value,
dtypes.int8: lambda x: ctypes.c_int8(x).value, dtypes.int16: lambda x: ctypes.c_int16(x).value,
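
Several of the new entries encode behavior that is easy to verify in plain Python; the snippet below is a sketch of those edge cases, with values borrowed from the tests elsewhere in this commit:

    import math, struct

    # math.sin rejects infinities outright, hence the explicit isinf guard in python_alu
    try:
        math.sin(math.inf)
    except ValueError as e:
        print(e)                                   # math domain error

    # copysign keeps the sign of a signed zero, so the reciprocal of -0.0 maps to -inf
    recip = lambda x: 1/x if x != 0 else math.copysign(math.inf, x)
    print(recip(-0.0), recip(0.0), recip(4.0))     # -inf inf 0.25

    # truncate_fp16 relies on struct.pack raising once a value can no longer round into float16
    for v in (65504.0, 65519.999, 65520.0):
        try:
            struct.pack("@e", v)
            print(v, "fits in float16")
        except OverflowError:
            print(v, "-> inf")
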
4 changes: 3 additions & 1 deletion tinygrad/runtime/ops_python.py
@@ -7,7 +7,7 @@
from tinygrad.helpers import all_same, getenv, flatten
from tinygrad.device import Compiled, Compiler, Allocator
from tinygrad.codegen.uops import UOpGraph, UOps
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu
from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, truncate
from tinygrad.renderer import Renderer
from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer

@@ -110,6 +110,8 @@ def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tup
if dtypes.is_int(dtype):
overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
elif dtypes.is_float(dtype):
casted = [truncate.get(dtype, lambda dt: dt)(x) for x in casted]
ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
elif uop is UOps.LOAD:
if isinstance(dtp[0], ImageDType):
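
With the new branch, a float CAST in the emulator first runs each value through the matching truncate entry (falling back to identity), so out-of-range values become ±inf instead of blowing up inside struct.pack. A rough sketch of that path, reusing the truncate table this commit adds to tinygrad.ops:

    import struct
    from tinygrad import dtypes
    from tinygrad.ops import truncate

    vals = [80000.0, -1e9, 1.5]                             # the first two exceed the float16 range
    casted = [truncate.get(dtypes.float16, lambda x: x)(v) for v in vals]
    print(struct.unpack("3e", struct.pack("3e", *casted)))  # (inf, -inf, 1.5)
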
9 changes: 6 additions & 3 deletions tinygrad/tensor.py
@@ -11,7 +11,7 @@
from tinygrad.helpers import IMAGE, DEBUG, WINO, THREEFRY
from tinygrad.lazy import LazyBuffer
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import LoadOps
from tinygrad.ops import LoadOps, truncate
from tinygrad.device import Device, Buffer, BufferOptions
from tinygrad.shape.symbolic import sint, Variable, MulNode, SumNode, NumNode, Node
from tinygrad.engine.realize import run_schedule
@@ -51,7 +51,8 @@ def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer:
else:
ret = LazyBuffer.loadop(LoadOps.EMPTY, get_shape(x), dtype, "PYTHON")
assert dtype.fmt is not None, f"{dtype=} has None fmt"
data = struct.pack(f"@{ret.size}{dtype.fmt}", *fully_flatten(x))
truncate_function = truncate[dtype]
data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)])
# fake realize
ret.buffer.allocate(memoryview(data))
del ret.srcs
@@ -1285,7 +1286,8 @@ def sum(self, axis=None, keepdim=False, acc_dtype:Optional[DType]=None):
```
"""
ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim)
return ret.cast(self.dtype) if self.dtype in {dtypes.float16, dtypes.bfloat16} else ret
return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret

def max(self, axis=None, keepdim=False):
"""
Returns the maximum value of the tensor along the specified axis or axes.
@@ -1308,6 +1310,7 @@ def max(self, axis=None, keepdim=False):
```
"""
return self._reduce(F.Max, axis, keepdim)

def min(self, axis=None, keepdim=False):
"""
Returns the minimum value of the tensor along the specified axis or axes.
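
The one-line change to sum() means an explicit acc_dtype now also determines the output dtype; only the default path still casts back down to float16/bfloat16. A minimal sketch of the difference, mirroring test_sum_acc_dtype above and assuming a device with float16 support:

    from tinygrad import Tensor, dtypes

    t = Tensor([40000, 40000], dtype=dtypes.float16)
    print(t.sum().dtype, t.sum().numpy())                   # cast back to float16 -> overflows to inf
    s = t.sum(acc_dtype=dtypes.float32)
    print(s.dtype, s.numpy())                               # stays float32 -> 80000.0
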
