Skip to content

Commit

Permalink
Merge pull request #42 from tomwhite/overflow-bug
Browse files Browse the repository at this point in the history
Fix overflow bug exposed by Hypothesis failures
  • Loading branch information
TomAugspurger committed Aug 14, 2020
2 parents 48b0eda + 4d4253e commit 001c01a
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 18 deletions.
12 changes: 1 addition & 11 deletions rechunker/algorithm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
"""Core rechunking algorithm stuff."""
from typing import Sequence, Optional, List, Tuple

try:
from math import prod

numpy_prod = False
except ImportError:
numpy_prod = True
from numpy import prod
from rechunker.compat import prod


def consolidate_chunks(
Expand Down Expand Up @@ -119,10 +113,6 @@ def rechunking_plan(

source_chunk_mem = itemsize * prod(source_chunks)
target_chunk_mem = itemsize * prod(target_chunks)
if numpy_prod:
# Convert to Python type for JSON serialization
source_chunk_mem = source_chunk_mem.item()
target_chunk_mem = target_chunk_mem.item()

if source_chunk_mem > max_mem:
raise ValueError(
Expand Down
13 changes: 13 additions & 0 deletions rechunker/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from functools import reduce
import operator
from typing import Sequence


def prod(iterable: Sequence[int]) -> int:
"""Implementation of `math.prod()` all Python versions."""
try:
from math import prod as mathprod # Python 3.8

return mathprod(iterable)
except ImportError:
return reduce(operator.mul, iterable, 1)
8 changes: 1 addition & 7 deletions tests/test_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,7 @@
#!/usr/bin/env python

"""Tests for `rechunker` package."""
try:
from math import prod
except ImportError:
from numpy import prod

from rechunker.compat import prod

import pytest
from hypothesis import given, assume
Expand Down Expand Up @@ -215,8 +211,6 @@ def shapes_chunks_maxmem_for_ndim(draw):
shapes_chunks_maxmem(ndim=ndim, itemsize=4, max_len=10_000)
)
max_mem = min_mem * 10
# needed to handle overflows
assume(max_mem > 1)
return shape, source_chunks, target_chunks, max_mem, itemsize


Expand Down
10 changes: 10 additions & 0 deletions tests/test_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import numpy as np
from rechunker.compat import prod


def test_prod():
assert prod(()) == 1
assert prod((2,)) == 2
assert prod((2, 3)) == 6
n = np.iinfo(np.int64).max
assert prod((n, 2)) == n * 2

0 comments on commit 001c01a

Please sign in to comment.