Skip to content

Commit

Permalink
test detect scale when SCALE_DETECT_ENABLED is True
Browse files Browse the repository at this point in the history
  • Loading branch information
willgraf committed Sep 5, 2019
1 parent e765672 commit c449602
Showing 1 changed file with 31 additions and 3 deletions.
34 changes: 31 additions & 3 deletions redis_consumer/consumers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,15 +408,35 @@ def dummydata(*_, **__):
def test_detect_scale(self):
redis_client = DummyRedis([])
consumer = consumers.TensorFlowServingConsumer(redis_client, None, 'q')
image = _get_image(settings.SCALE_RESHAPE_SIZE * 2,
settings.SCALE_RESHAPE_SIZE * 2)
big_size = settings.SCALE_RESHAPE_SIZE * np.random.randint(2, 9)
image = _get_image(big_size, big_size)

expected = (settings.SCALE_RESHAPE_SIZE / (big_size)) ** 2

settings.SCALE_DETECT_MODEL = 'dummymodel:1'

consumer.grpc_image = lambda *x: np.random.randint(0, 1, shape=(1, 3))
def grpc_image(*_, **__):
sign = -1 if np.random.randint(1, 5) > 2 else 1
return expected + sign * 1e-8 # small differences get averaged out

consumer.grpc_image = grpc_image

settings.SCALE_DETECT_ENABLED = False

scale = consumer.detect_scale(image)
assert scale == 1

settings.SCALE_DETECT_ENABLED = True

scale = consumer.detect_scale(image)
assert isinstance(scale, (float, int))
np.testing.assert_almost_equal(scale, expected)

consumer.grpc_image = grpc_image

scale = consumer.detect_scale(np.expand_dims(image, axis=-1))
assert isinstance(scale, (float, int))
np.testing.assert_almost_equal(scale, expected)


class TestImageFileConsumer(object):
Expand All @@ -438,6 +458,7 @@ def test_is_valid_hash(self):
assert consumer.is_valid_hash('predict:1234567890:file.png') is True

def test__get_processing_function(self):
_funcs = settings.PROCESSING_FUNCTIONS
settings.PROCESSING_FUNCTIONS = {
'valid': {
'valid': lambda x: True
Expand All @@ -456,7 +477,10 @@ def test__get_processing_function(self):
with pytest.raises(ValueError):
consumer._get_processing_function('valid', 'invalid')

settings.PROCESSING_FUNCTIONS = _funcs

def test_process(self):
_funcs = settings.PROCESSING_FUNCTIONS
settings.PROCESSING_FUNCTIONS = {
'valid': {
'valid': lambda x: x
Expand All @@ -469,7 +493,11 @@ def test_process(self):
output = consumer.process(img, 'valid', 'valid')
assert img.shape[1:] == output.shape

settings.PROCESSING_FUNCTIONS = _funcs

def test__consume(self):
settings.LABEL_DETECT_ENABLED = False
settings.SCALE_DETECT_ENABLED = False
prefix = 'predict'
status = 'new'
redis_client = DummyRedis(prefix, status)
Expand Down

0 comments on commit c449602

Please sign in to comment.