Skip to content

Commit

Permalink
update base consumer to take encoder flag and create wraper, update t…
Browse files Browse the repository at this point in the history
…racking consumer
  • Loading branch information
MekWarrior committed Mar 24, 2021
1 parent d63d31b commit 65f39aa
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 3 deletions.
12 changes: 11 additions & 1 deletion redis_consumer/consumers/base_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,7 @@ def get_model_metadata(self, model_name, model_version):
self.logger.error('Malformed metadata: %s', model_metadata)
raise err

def get_grpc_app(self, model, application_cls, **kwargs):
def get_grpc_app(self, model, application_cls, encoder_req=False, **kwargs):
"""
Create an application from deepcell.applications
with a gRPC model wrapper as a model
Expand All @@ -397,6 +397,16 @@ def get_grpc_app(self, model, application_cls, **kwargs):
model_metadata = self.get_model_metadata(model_name, model_version)
client = self._get_predict_client(model_name, model_version)
model_wrapper = GrpcModelWrapper(client, model_metadata)

if encoder_req:
encoder = kwargs['encoder']
encoder_name, encoder_version = encoder.split(':')
encoder_metadata = self.get_model_metadata(encoder_name, encoder_version)
client = self._get_predict_client(encoder_name, encoder_version)
encoder_wrapper = GrpcModelWrapper(client, encoder_metadata)
del kwargs['encoder']
return application_cls(model_wrapper, encoder_wrapper, **kwargs)

return application_cls(model_wrapper, **kwargs)

def detect_scale(self, image):
Expand Down
2 changes: 2 additions & 0 deletions redis_consumer/consumers/tracking_consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,8 @@ def _consume(self, redis_hash):

# Send data to the model
app = self.get_grpc_app(settings.TRACKING_MODEL, CellTracking,
encoder_req=True,
encoder=settings.NEIGHBORHOOD_ENCODER,
birth=settings.BIRTH,
death=settings.DEATH,
division=settings.DIVISION,
Expand Down
4 changes: 2 additions & 2 deletions redis_consumer/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,15 @@

# Tracking settings
TRACKING_MODEL = config('TRACKING_MODEL', default='TrackingModel:0', cast=str)
NEIGHBORHOOD_ENCODER = config('NEIGHBORHOOD_ENCODER', default='TrackingModelNE:0', cast=str)
DRIFT_CORRECT_ENABLED = config('DRIFT_CORRECT_ENABLED', default=False, cast=bool)

# tracking.cell_tracker settings TODO: can we extract from model_metadata?
MAX_DISTANCE = config('MAX_DISTANCE', default=50, cast=int)
TRACK_LENGTH = config('TRACK_LENGTH', default=9, cast=int)
TRACK_LENGTH = config('TRACK_LENGTH', default=5, cast=int)
DIVISION = config('DIVISION', default=0.9, cast=float)
BIRTH = config('BIRTH', default=0.99, cast=float)
DEATH = config('DEATH', default=0.99, cast=float)
NEIGHBORHOOD_SCALE_SIZE = config('NEIGHBORHOOD_SCALE_SIZE', default=30, cast=int)

# Scale detection settings
SCALE_DETECT_MODEL = config('SCALE_DETECT_MODEL', default='ScaleDetection:1')
Expand Down

0 comments on commit 65f39aa

Please sign in to comment.