
RECIP(-0.0) should be -inf (#5024)
* RECIP(-0.0) should be -inf

added test_dtype_alu for PYTHON backend

* catch that

* fix those two
chenyuxyz committed Jun 18, 2024
1 parent 66760ae commit acaf9a4
Showing 6 changed files with 20 additions and 14 deletions.
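
A note on the semantics being fixed: in IEEE 754, dividing by negative zero yields negative infinity, but Python's float division raises ZeroDivisionError for either zero, so the emulator special-cases zero and previously returned +inf regardless of the zero's sign. Since -0.0 == 0 compares equal, math.copysign (which reads the sign bit) is the way to tell the two apart. A minimal sketch of the fixed RECIP behavior:

  import math

  def recip(x: float) -> float:
    # mirrors UnaryOps.RECIP in tinygrad/ops.py after this commit
    return 1/x if x != 0 else math.copysign(math.inf, x)

  print(recip(-0.0))  # -inf (the old code returned inf here)
  print(recip(0.0))   # inf
  print(recip(4.0))   # 0.25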
.github/workflows/test.yml (2 changes: 1 addition & 1 deletion)
@@ -77,7 +77,7 @@ jobs:
 PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
 PYTHONPATH=. DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 ./test/test_linearizer.py TestLinearizer.test_tensor_cores
 - name: Test dtype with Python emulator
-  run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 test/test_dtype.py
+  run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 -m pytest test/test_dtype.py test/test_dtype_alu.py
 - name: Test ops with Python emulator
   run: DEBUG=2 PYTHON=1 python3 -m pytest test/test_ops.py -k "not (test_split or test_simple_cumsum or test_cumsum or test_einsum or test_dot or test_dot_1d or test_big_gemm or test_broadcastdot or test_multidot or test_var_axis or test_std_axis or test_broadcast_full or test_broadcast_partial or test_simple_conv3d or test_dilated_conv_transpose2d or test_simple_conv_transpose3d or test_large_input_conv2d or test_maxpool2d or test_maxpool2d_simple or test_maxpool2d_bigger_stride or test_avgpool2d or test_cat or test_scaled_product_attention or test_scaled_product_attention_causal or test_slice_fancy_indexing_dim_inject_none or test_slice_fancy_indexing_list_indices or test_slice_fancy_indexing_no_dim_collapse or test_slice_fancy_indexing_tuple_indices or test_slice_fancy_indexing_list_with_tensors or test_slice_fancy_indexing_dim_collapse_int)" --durations=20
 - name: Test uops with Python emulator
test/test_dtype.py (6 changes: 5 additions & 1 deletion)
@@ -112,7 +112,11 @@ def test_resulting_and_init_dtypes_match(self):
     dtypes = list(map(np.dtype, ["bool", "uint8", "int8", "int16", "int32", "int64", "float32", "float64"]))
     data = [1., 2., 0., 0.5, -1.5, 5.25]
     for dt in dtypes:
-      arr = np.asarray(data, dtype=dt)
+      try:
+        arr = np.asarray(data, dtype=dt)
+      except OverflowError:
+        # TODO: this happens with numpy 2.0, update with proper behavior
+        continue
       tin = Tensor(arr).numpy()
       tor = torch.as_tensor(arr).detach().numpy()
       assert dt == tin.dtype == tor.dtype, f"dtype mismatch: expected={dt} | tinygrad={tin.dtype} | torch={tor.dtype}"
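
The try/except above works around NumPy 2.0, where conversions that NumPy 1.x wrapped or truncated silently can now raise OverflowError; the commit leaves proper handling as a TODO. A sketch of the guard in isolation (which dtypes in the list actually raise is not stated in the commit; the negative values with unsigned dtypes are the likely trigger):

  import numpy as np

  data = [1., 2., 0., 0.5, -1.5, 5.25]
  for dt in map(np.dtype, ["uint8", "int8", "float32"]):
    try:
      arr = np.asarray(data, dtype=dt)
    except OverflowError:
      print(f"{dt}: raises under numpy 2.0")  # numpy 1.x converted silently
      continue
    print(dt, arr)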
test/test_dtype_alu.py (1 change: 1 addition & 0 deletions)
@@ -150,6 +150,7 @@ def test_int32_midcast_float(self, a, b, c, op1, op2): universal_test_midcast(a,
   @given(strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
          strat.floats(width=32, min_value=0, max_value=10.0) if skip_overflow else ht.float32,
          ht.int32, strat.sampled_from(binary_operations), strat.sampled_from(integer_binary_operations))
+  @unittest.skipIf(Device.DEFAULT == "PYTHON", "TODO: fix cast inf to int32 in PYTHON")
   def test_float_midcast_int32(self, a, b, c, op1, op2): universal_test_midcast(a, b, c, op1, op2, dtypes.float32, dtypes.int32)

   @unittest.skip("broken. TODO: fix it")
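
The new skip reflects a gap in the pure-Python emulator: op1 can overflow float32 to inf mid-expression, and the following cast to int32 has no well-defined result in plain Python, because int() rejects infinity outright while NumPy's cast yields an implementation-defined value. A sketch of the failure mode, assuming the emulator casts through int():

  import math
  import numpy as np

  x = math.inf  # e.g. a float32 overflow produced by op1
  try:
    int(x)
  except OverflowError as e:
    print(e)  # cannot convert float infinity to integer

  # numpy instead emits a RuntimeWarning and returns a platform-defined value
  print(np.float32(np.inf).astype(np.int32))  # commonly -2147483648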
test/test_tensor.py (19 changes: 10 additions & 9 deletions)
@@ -309,15 +309,16 @@ def test_tensor_list_special_values(self):
     data = data + [-x for x in data]
     np.testing.assert_allclose(Tensor(data, dtype=dtypes.float16).numpy(), np.array(data, dtype=np.float16))

-    # uint32
-    data = [1 << 33, 1 << 32, 1 << 32 - 1, 1]
-    data = data + [-x for x in data]
-    np.testing.assert_allclose(Tensor(data, dtype=dtypes.uint32).numpy(), np.array(data, dtype=np.uint32))
-
-    # int32
-    data = [1 << 33, 1 << 32, 1 << 32 - 1, 1]
-    data = data + [-x for x in data]
-    np.testing.assert_allclose(Tensor(data, dtype=dtypes.int32).numpy(), np.array(data, dtype=np.int32))
+    # TODO: numpy changed this behavior in 2.0
+    # # uint32
+    # data = [1 << 33, 1 << 32, 1 << 32 - 1, 1]
+    # data = data + [-x for x in data]
+    # np.testing.assert_allclose(Tensor(data, dtype=dtypes.uint32).numpy(), np.array(data, dtype=np.uint32))
+
+    # # int32
+    # data = [1 << 33, 1 << 32, 1 << 32 - 1, 1]
+    # data = data + [-x for x in data]
+    # np.testing.assert_allclose(Tensor(data, dtype=dtypes.int32).numpy(), np.array(data, dtype=np.int32))

   def test_tensor_bytes(self):
     data = b"abc123"
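
The commented-out block relied on NumPy 1.x wrapping out-of-range Python ints when building the uint32/int32 arrays; NumPy 2.0 raises OverflowError for out-of-bounds Python integers instead, hence the TODO. Two details of the old expectation, sketched as a check (note the precedence quirk in the test data):

  print(1 << 32 - 1)            # 2147483648, parsed as 1 << 31, not (1 << 32) - 1
  print((1 << 33) % (1 << 32))  # 0, the value the old uint32 wrap-around produced
  # under numpy 2.0 the old construction raises instead:
  #   np.array([1 << 33], dtype=np.uint32)  ->  OverflowError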
test/unit/test_disk_tensor.py (2 changes: 1 addition & 1 deletion)
@@ -59,7 +59,7 @@ def test_bitcasts_on_disk(self):
     _test_bitcasted(t, dtypes.uint32, 0)
     # pi in float16 stored via int16
     t.bitcast(dtypes.uint16).assign(Tensor.full((128, 64), 0x4248, dtype=dtypes.uint16)).realize()
-    _test_bitcasted(t, dtypes.float16, 3.141)
+    _test_bitcasted(t, dtypes.float16, 3.140625)
     _test_bitcasted(t, dtypes.float32, 50.064727)
     _test_bitcasted(t, dtypes.uint16, 0x4248)
     _test_bitcasted(t, dtypes.uint32, 0x42484248)
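
The old expected value 3.141 was a rounded approximation; 0x4248 decodes to exactly 3.140625, the nearest float16 to pi, which is presumably what the comparison needs. Decoding the half-precision bits by hand as a check:

  import struct

  bits = 0x4248
  sign = bits >> 15           # 0
  exp = (bits >> 10) & 0x1f   # 16
  mant = bits & 0x3ff         # 584
  print((-1)**sign * (1 + mant/1024) * 2.0**(exp - 15))   # 3.140625
  print(struct.unpack("<e", struct.pack("<H", bits))[0])  # 3.140625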
tinygrad/ops.py (4 changes: 2 additions & 2 deletions)
@@ -120,7 +120,7 @@ def wfxn(*args):
   UnaryOps.EXP2: hook_overflow(math.inf, lambda x: math.exp(x*math.log(2))),
   UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan,
   UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan,
-  UnaryOps.RECIP: lambda x: 1/x if x != 0 else float('inf'),
+  UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x),
   UnaryOps.NEG: lambda x: (not x) if isinstance(x, bool) else -x,
   BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift,
   BinaryOps.MUL: operator.mul, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub,
@@ -133,7 +133,7 @@ def truncate_fp16(x):
     x = float(x)
     struct.pack("@e", x)
     return x
-  except OverflowError: return x * math.inf
+  except OverflowError: return math.copysign(math.inf, x)

 truncate: Dict[DType, Callable] = {dtypes.bool: bool,
   # TODO: bfloat16
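
For inputs that can actually reach the except branch (magnitudes beyond the float16 range), x * math.inf and math.copysign(math.inf, x) agree; copysign simply states the sign handling explicitly, matching the RECIP fix above. A runnable sketch of the helper as it reads after this commit:

  import math, struct

  def truncate_fp16(x):
    # struct's "e" format is IEEE 754 half precision; packing raises on overflow
    try:
      x = float(x)
      struct.pack("@e", x)
      return x
    except OverflowError: return math.copysign(math.inf, x)

  print(truncate_fp16(70000.0))   # inf  (float16 max is 65504)
  print(truncate_fp16(-70000.0))  # -inf
  print(truncate_fp16(3.14159))   # 3.14159: in-range values pass through unrounded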
