diff --git a/examples/mnist/estimator/mnist_pipeline.py b/examples/mnist/estimator/mnist_pipeline.py index 0939df1f..87a2c256 100644 --- a/examples/mnist/estimator/mnist_pipeline.py +++ b/examples/mnist/estimator/mnist_pipeline.py @@ -181,6 +181,7 @@ def parse(ln): model = TFModel(args) \ .setInputMapping({'image': 'features'}) \ .setOutputMapping({'logits': 'prediction'}) \ + .setSignatureDefKey('serving_default') \ .setExportDir(args.export_dir) \ .setBatchSize(args.batch_size) diff --git a/examples/mnist/keras/mnist_pipeline.py b/examples/mnist/keras/mnist_pipeline.py index 00365070..f20dd50f 100644 --- a/examples/mnist/keras/mnist_pipeline.py +++ b/examples/mnist/keras/mnist_pipeline.py @@ -134,6 +134,7 @@ def parse(ln): model = TFModel(args) \ .setInputMapping({'image': 'conv2d_input'}) \ .setOutputMapping({'dense_1': 'prediction'}) \ + .setSignatureDefKey('serving_default') \ .setExportDir(args.export_dir) \ .setBatchSize(args.batch_size) diff --git a/tensorflowonspark/TFNode.py b/tensorflowonspark/TFNode.py index 85a3f812..6bbed0bb 100644 --- a/tensorflowonspark/TFNode.py +++ b/tensorflowonspark/TFNode.py @@ -231,6 +231,9 @@ def __init__(self, mgr, train_mode=True, qname_in='input', qname_out='output', i self.done_feeding = False self.input_tensors = [tensor for col, tensor in sorted(input_mapping.items())] if input_mapping is not None else None + self.queue_in = mgr.get_queue(qname_in) + self.queue_out = mgr.get_queue(qname_out) + def next_batch(self, batch_size): """Gets a batch of items from the input RDD. @@ -249,34 +252,33 @@ def next_batch(self, batch_size): Returns: A batch of items or a dictionary of tensors. """ - logger.debug("next_batch() invoked") - queue = self.mgr.get_queue(self.qname_in) tensors = [] if self.input_tensors is None else {tensor: [] for tensor in self.input_tensors} count = 0 + queue_in = self.queue_in + no_input_tensors = self.input_tensors is None while count < batch_size: - item = queue.get(block=True) + item = queue_in.get(block=True) if item is None: # End of Feed logger.info("next_batch() got None") - queue.task_done() + queue_in.task_done() self.done_feeding = True break elif type(item) is marker.EndPartition: # End of Partition logger.info("next_batch() got EndPartition") - queue.task_done() + queue_in.task_done() if not self.train_mode and count > 0: break else: # Normal item - if self.input_tensors is None: + if no_input_tensors: tensors.append(item) else: for i in range(len(self.input_tensors)): tensors[self.input_tensors[i]].append(item[i]) count += 1 - queue.task_done() - logger.debug("next_batch() returning {0} items".format(count)) + queue_in.task_done() return tensors def should_stop(self): @@ -292,11 +294,9 @@ def batch_results(self, results): Args: :results: array of output data for the equivalent batch of input data. """ - logger.debug("batch_results() invoked") - queue = self.mgr.get_queue(self.qname_out) + queue = self.queue_out for item in results: queue.put(item, block=True) - logger.debug("batch_results() returning data") def terminate(self): """Terminate data feeding early.