Skip to content

Commit

Permalink
minor refactor overflow handing in python backend (#5015)
Browse files Browse the repository at this point in the history
made it clear that it's only handing int now. need to handle float inf next
  • Loading branch information
chenyuxyz committed Jun 17, 2024
1 parent 1ad3b25 commit 013c73c
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tinygrad/runtime/ops_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,18 @@ def __call__(self, *bufs, global_size:Tuple[int,int,int]=(1,1,1), local_size:Tup
del ul[i]
i = loop_ends[i] + 1
continue
elif uop in {UOps.CAST, UOps.BITCAST}:
elif uop in (UOps.CAST, UOps.BITCAST):
if dtype.count > 1: ul[i] = inp
else:
assert dtp[0].fmt and dtype.fmt
pack_format, unpack_format = str(warp_size) + dtp[0].fmt, str(warp_size) + dtype.fmt
if uop is UOps.BITCAST: ul[i] = list(struct.unpack(unpack_format, struct.pack(pack_format, *inp[0])))
else:
casted = [dtypes.as_const(x, dtype) for x in inp[0]]
overflow_adjust = 2**(dtype.itemsize*8 - 1) if (dtypes.is_int(dtype) and not dtypes.is_unsigned(dtype)) else 0
overflow_fixed = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) if dtypes.is_int(dtype) else x for x in casted]
ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *overflow_fixed)))
if dtypes.is_int(dtype):
overflow_adjust = 2**(dtype.itemsize*8 - 1) if not dtypes.is_unsigned(dtype) else 0
casted = [((x + overflow_adjust) % 2**(dtype.itemsize*8) - overflow_adjust) for x in casted]
ul[i] = list(struct.unpack(unpack_format, struct.pack(unpack_format, *casted)))
elif uop is UOps.LOAD:
if isinstance(dtp[0], ImageDType):
assert dtype.count == 4
Expand Down

0 comments on commit 013c73c

Please sign in to comment.