Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/mnist/estimator/mnist_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def parse(ln):
model = TFModel(args) \
.setInputMapping({'image': 'features'}) \
.setOutputMapping({'logits': 'prediction'}) \
.setSignatureDefKey('serving_default') \
.setExportDir(args.export_dir) \
.setBatchSize(args.batch_size)

Expand Down
1 change: 1 addition & 0 deletions examples/mnist/keras/mnist_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def parse(ln):
model = TFModel(args) \
.setInputMapping({'image': 'conv2d_input'}) \
.setOutputMapping({'dense_1': 'prediction'}) \
.setSignatureDefKey('serving_default') \
.setExportDir(args.export_dir) \
.setBatchSize(args.batch_size)

Expand Down
22 changes: 11 additions & 11 deletions tensorflowonspark/TFNode.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,9 @@ def __init__(self, mgr, train_mode=True, qname_in='input', qname_out='output', i
self.done_feeding = False
self.input_tensors = [tensor for col, tensor in sorted(input_mapping.items())] if input_mapping is not None else None

self.queue_in = mgr.get_queue(qname_in)
self.queue_out = mgr.get_queue(qname_out)

def next_batch(self, batch_size):
"""Gets a batch of items from the input RDD.

Expand All @@ -249,34 +252,33 @@ def next_batch(self, batch_size):
Returns:
A batch of items or a dictionary of tensors.
"""
logger.debug("next_batch() invoked")
queue = self.mgr.get_queue(self.qname_in)
tensors = [] if self.input_tensors is None else {tensor: [] for tensor in self.input_tensors}
count = 0
queue_in = self.queue_in
no_input_tensors = self.input_tensors is None
while count < batch_size:
item = queue.get(block=True)
item = queue_in.get(block=True)
if item is None:
# End of Feed
logger.info("next_batch() got None")
queue.task_done()
queue_in.task_done()
self.done_feeding = True
break
elif type(item) is marker.EndPartition:
# End of Partition
logger.info("next_batch() got EndPartition")
queue.task_done()
queue_in.task_done()
if not self.train_mode and count > 0:
break
else:
# Normal item
if self.input_tensors is None:
if no_input_tensors:
tensors.append(item)
else:
for i in range(len(self.input_tensors)):
tensors[self.input_tensors[i]].append(item[i])
count += 1
queue.task_done()
logger.debug("next_batch() returning {0} items".format(count))
queue_in.task_done()
return tensors

def should_stop(self):
Expand All @@ -292,11 +294,9 @@ def batch_results(self, results):
Args:
:results: array of output data for the equivalent batch of input data.
"""
logger.debug("batch_results() invoked")
queue = self.mgr.get_queue(self.qname_out)
queue = self.queue_out
for item in results:
queue.put(item, block=True)
logger.debug("batch_results() returning data")

def terminate(self):
"""Terminate data feeding early.
Expand Down