
Commit

Merge aa34f87 into 774d8e3
msschwartz21 committed Jul 31, 2019
2 parents 774d8e3 + aa34f87 commit 4125c5c
Showing 5 changed files with 37 additions and 13 deletions.
20 changes: 10 additions & 10 deletions redis_consumer/consumers.py
@@ -935,12 +935,12 @@ def _get_tracker(self, redis_hash, hvalues, raw, segmented):
         features = {'appearance', 'distance', 'neighborhood', 'regionprop'}
         tracker = tracking.cell_tracker(raw, segmented,
                                         tracking_model,
-                                        max_distance=50,
-                                        track_length=5,
-                                        division=0.5,
-                                        birth=0.9,
-                                        death=0.9,
-                                        neighborhood_scale_size=30,
+                                        max_distance=settings.MAX_DISTANCE,
+                                        track_length=settings.TRACK_LENGTH,
+                                        division=settings.DIVISION,
+                                        birth=settings.BIRTH,
+                                        death=settings.DEATH,
+                                        neighborhood_scale_size=settings.NEIGHBORHOOD_SCALE_SIZE,
                                         features=features)

         self.logger.debug('Created tracker!')
@@ -1007,10 +1007,10 @@ def _load_data(self, hvalues, subdir, fname):
             'identity_upload': self.hostname,
             'input_file_name': upload_file_name,
             'original_name': segment_fname,
-            'model_name': 'HeLaS3watershed',
-            'model_version': 2,
-            'postprocess_function': 'watershed',
-            'cuts': 0,
+            'model_name': settings.MODEL_NAME,
+            'model_version': settings.MODEL_VERSION,
+            'postprocess_function': settings.POSTPROCESS_FUNCTION,
+            'cuts': settings.CUTS,
             'status': 'new',
             'created_at': current_timestamp,
             'updated_at': current_timestamp,
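With the tracker hyperparameters now read from redis_consumer.settings, they can also be overridden in tests without touching consumers.py. A minimal pytest-style sketch, assuming only that the settings module exposes these names as plain module attributes (the test name and values are illustrative, not part of this commit):

    from unittest import mock

    from redis_consumer import settings


    def test_tracker_reads_patched_settings():
        # Patch the module attributes that _get_tracker now reads.
        with mock.patch.object(settings, 'MAX_DISTANCE', 75), \
                mock.patch.object(settings, 'TRACK_LENGTH', 9):
            assert settings.MAX_DISTANCE == 75
            assert settings.TRACK_LENGTH == 9
            # A call to _get_tracker(...) inside this block would pick up
            # the patched values instead of the module defaults.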
8 changes: 6 additions & 2 deletions redis_consumer/grpc_clients.py
@@ -56,6 +56,7 @@ class GrpcClient(object):
     Arguments:
         host: string, the hostname and port of the server (`localhost:8080`)
     """
+
     def __init__(self, host):
         self.logger = logging.getLogger(self.__class__.__name__)
         self.host = host
@@ -87,6 +88,7 @@ class PredictClient(GrpcClient):
         model_name: string, name of model served by tensorflow-serving
         model_version: integer, version of the named model
     """
+
     def __init__(self, host, model_name, model_version):
         super(PredictClient, self).__init__(host)
         self.model_name = model_name
@@ -157,6 +159,7 @@ class ProcessClient(GrpcClient):
         process_type: string, pre or post processing
         function_name: string, name of processing function
     """
+
     def __init__(self, host, process_type, function_name):
         super(ProcessClient, self).__init__(host)
         self.process_type = process_type
@@ -297,6 +300,7 @@ class TrackingClient(GrpcClient):
         model_name: string, name of model served by tensorflow-serving
         model_version: integer, version of the named model
     """
+
     def __init__(self, host, redis_hash, model_name, model_version, progress_callback):
         super(TrackingClient, self).__init__(host)
         self.redis_hash = redis_hash
@@ -328,7 +332,6 @@ def _predict(self, data, request_timeout=100):
         num_preds = data[0].shape[0]

         predictions = []
-        self.logger.debug('Sending %i requests...', num_preds)
         for data_i in range(num_preds):
             request = PredictRequest()
             request.model_spec.name = self.model_name  # pylint: disable=E1101
@@ -341,7 +344,8 @@ def _predict(self, data, request_timeout=100):
                 tensor_proto = make_tensor_proto(model_input, 'DT_FLOAT')
                 request.inputs["input{}".format(i)].CopyFrom(tensor_proto)

-            predictions.append(self._single_request(stub, request))
+            # Select only last dimension in order to drop batch axis
+            predictions.append(self._single_request(stub, request)[-1])

         self.logger.info('Predicting everything took: %s seconds',
                          timeit.default_timer() - t)
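The new [-1] index keeps only the last entry of each per-frame response, which the added comment describes as dropping the batch axis. A minimal sketch of the indexing itself, assuming the result behaves like a NumPy array with a leading batch dimension of size 1 (shapes are illustrative):

    import numpy as np

    # Pretend this is one serving response: a single-image batch.
    response = np.zeros((1, 128, 128, 3))

    # [-1] selects the last (and only) entry along the first axis,
    # leaving an unbatched prediction.
    prediction = response[-1]
    print(prediction.shape)  # (128, 128, 3)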
7 changes: 7 additions & 0 deletions redis_consumer/redis_test.py
@@ -70,6 +70,13 @@ def sentinel_slaves(self, _):
         n = random.randint(1, 4)
         return [{'ip': 'slave', 'port': 6379} for i in range(n)]

+    def sentinel_masters(self):
+        return {'mymaster': {'ip': 'master', 'port': 6379}}
+
+    def sentinel_slaves(self, _):
+        n = random.randint(1, 4)
+        return [{'ip': 'slave', 'port': 6379} for i in range(n)]
+

 class TestRedis(object):

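The added mock methods mirror the shape of redis-py's sentinel commands: sentinel_masters() maps each service name to a dict of master attributes, and sentinel_slaves(service_name) returns a list of such dicts. A hypothetical usage sketch showing how code under test might resolve a master address from the mocked responses (the helper below is illustrative, not part of the repository):

    def resolve_master(client, service='mymaster'):
        # e.g. {'mymaster': {'ip': 'master', 'port': 6379}}
        masters = client.sentinel_masters()
        info = masters[service]
        return '{}:{}'.format(info['ip'], info['port'])

    # With the mock above, resolve_master(mock_client) == 'master:6379'.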
14 changes: 14 additions & 0 deletions redis_consumer/settings.py
@@ -127,3 +127,17 @@ def _strip(x):
         'retinanet': processing.retinanet_to_label_image,
     },
 }
+
+# Tracking settings
+MODEL_NAME = config('MODEL_NAME', default='HeLaS3watershed')
+MODEL_VERSION = config('MODEL_VERSION', default=2)
+POSTPROCESS_FUNCTION = config('POSTPROCESS_FUNCTION', default='watershed')
+CUTS = config('CUTS', default=0)
+
+# tracking.cell_tracker settings
+MAX_DISTANCE = config('MAX_DISTANCE', default=50)
+TRACK_LENGTH = config('TRACK_LENGTH', default=5)
+DIVISION = config('DIVISION', default=0.9)
+BIRTH = config('BIRTH', default=0.95)
+DEATH = config('DEATH', default=0.95)
+NEIGHBORHOOD_SCALE_SIZE = config('NEIGHBORHOOD_SCALE_SIZE', default=30)
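These settings follow the config(name, default=...) pattern used elsewhere in settings.py, which typically resolves an environment variable first and falls back to the default (python-decouple-style; the exact helper is an assumption here). Values supplied through the environment usually arrive as strings unless a cast is applied, while the defaults keep their Python types. A minimal sketch of overriding one value at deploy time:

    import os

    # Hypothetical override: must be set before redis_consumer.settings is
    # imported, since config() runs at import time.
    os.environ['MAX_DISTANCE'] = '75'

    from redis_consumer import settings  # noqa: E402

    print(settings.MAX_DISTANCE)   # '75' (a string, if no cast is applied)
    print(settings.TRACK_LENGTH)   # 5, the default, since TRACK_LENGTH is unset

If the tracker expects numeric values, adding cast=int or cast=float to these config() calls (assuming a python-decouple-style helper) would keep environment overrides and defaults the same type.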
1 change: 0 additions & 1 deletion redis_consumer/utils.py
@@ -127,7 +127,6 @@ def make_tensor_proto(data, dtype):
            },
            number_to_dtype_value[dtype]: values
        }
-
    dict_to_protobuf.dict_to_protobuf(tensor_proto_dict, tensor_proto)

    return tensor_proto
