Commit 2ef1f5b

Replace filename queue with dataset api

tobegit3hub committed May 8, 2018
1 parent 3147931 commit 2ef1f5b

Showing 2 changed files with 663 additions and 60 deletions.
141 changes: 81 additions & 60 deletions dense_classifier.py
@@ -1,11 +1,15 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import datetime
import logging
import os
import sys
import pprint

import numpy as np
import tensorflow as tf
from sklearn import metrics
@@ -28,7 +32,7 @@ def define_flags():
"Support tfrecords, csv")
flags.DEFINE_string("train_file", "./data/cancer/cancer_train.csv.tfrecords",
"Train files which supports glob pattern")
flags.DEFINE_string("validate_file",
flags.DEFINE_string("validation_file",
"./data/cancer/cancer_test.csv.tfrecords",
"Validate files which supports glob pattern")
flags.DEFINE_string("inference_data_file", "./data/cancer/cancer_test.csv",
@@ -60,6 +64,13 @@ def define_flags():
flags.DEFINE_string("model_path", "./model/", "Path of the model")
flags.DEFINE_integer("model_version", 1, "Version of the model")
FLAGS = flags.FLAGS

# Print flags
parameter_value_map = {}
for key in FLAGS.__flags.keys():
parameter_value_map[key] = FLAGS.__flags[key].value
pprint.PrettyPrinter().pprint(parameter_value_map)

return FLAGS
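Editor's note: the new flag-printing block leans on the internal `FLAGS.__flags` mapping, which in TF 1.5+ holds Flag objects exposing `.value` (earlier versions stored raw values, which is why the `.value` indirection is needed). A standalone sketch of the same pattern:

```python
import pprint

import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_integer("batch_size", 10, "Batch size")
flags.DEFINE_string("mode", "train", "Run mode")
FLAGS = flags.FLAGS

# In TF 1.5+, FLAGS.__flags maps each flag name to a Flag object;
# .value holds the parsed (or default) value.
parameter_value_map = {key: FLAGS.__flags[key].value for key in FLAGS.__flags}
pprint.PrettyPrinter().pprint(parameter_value_map)
```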


@@ -76,7 +87,7 @@ def assert_flags(FLAGS):
return

logging.error("Get the unsupported parameters, exit now")
exit(1)
sys.exit(1)


def get_optimizer_by_name(optimizer_name, learning_rate):
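Editor's note: the body of `get_optimizer_by_name` is collapsed in this view. A plausible minimal implementation, assuming the usual optimizer names from this repository family (`sgd`, `adadelta`, `adagrad`, `adam`, `ftrl`, `rmsprop`); the actual mapping may differ:

```python
import tensorflow as tf

def get_optimizer_by_name(optimizer_name, learning_rate):
  # Map a flag string to a TF 1.x optimizer; fall back to plain SGD.
  optimizers = {
      "sgd": tf.train.GradientDescentOptimizer,
      "adadelta": tf.train.AdadeltaOptimizer,
      "adagrad": tf.train.AdagradOptimizer,
      "adam": tf.train.AdamOptimizer,
      "ftrl": tf.train.FtrlOptimizer,
      "rmsprop": tf.train.RMSPropOptimizer,
  }
  constructor = optimizers.get(optimizer_name, tf.train.GradientDescentOptimizer)
  return constructor(learning_rate)
```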
@@ -108,20 +119,6 @@ def restore_from_checkpoint(sess, saver, checkpoint):
return False


def read_and_decode_tfrecords(filename_queue):
reader = tf.TFRecordReader()
_, serialized_example = reader.read(filename_queue)
examples = tf.parse_single_example(
serialized_example,
features={
"label": tf.FixedLenFeature([], tf.float32),
"features": tf.FixedLenFeature([FLAGS.feature_size], tf.float32),
})
label = examples["label"]
features = examples["features"]
return label, features


def read_and_decode_csv(filename_queue):
# Notice that it supports label in the last column only
reader = tf.TextLineReader()
@@ -282,10 +279,17 @@ def inference(inputs, input_units, output_units, is_train=True):
return cnn_inference(inputs, input_units, output_units, is_train)


def _parse_tfrecords_function(example_proto):
features = {"features": tf.FixedLenFeature([FLAGS.feature_size], tf.float32),
"label": tf.FixedLenFeature([], tf.float32, default_value=0.0)}
parsed_features = tf.parse_single_example(example_proto, features)
return parsed_features["features"], parsed_features["label"]
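Editor's note: `_parse_tfrecords_function` expects each record to carry a fixed-length float `features` vector and a scalar float `label`. A sketch of writing a compatible file with the TF 1.x record writer (the `feature_size = 9` here is an assumption; it must equal `FLAGS.feature_size` at training time):

```python
import tensorflow as tf

feature_size = 9  # assumption: must match FLAGS.feature_size

writer = tf.python_io.TFRecordWriter("./data/example.csv.tfrecords")
for features, label in [([0.1] * feature_size, 1.0), ([0.2] * feature_size, 0.0)]:
  example = tf.train.Example(features=tf.train.Features(feature={
      "features": tf.train.Feature(float_list=tf.train.FloatList(value=features)),
      "label": tf.train.Feature(float_list=tf.train.FloatList(value=[label])),
  }))
  writer.write(example.SerializeToString())
writer.close()
```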



logging.basicConfig(level=logging.INFO)
FLAGS = define_flags()
assert_flags(FLAGS)
pprint.PrettyPrinter().pprint(FLAGS.__flags)
if FLAGS.enable_colored_log:
import coloredlogs
coloredlogs.install()
@@ -301,54 +305,68 @@ def main():
if os.path.exists(FLAGS.output_path) == False:
os.makedirs(FLAGS.output_path)



epoch_number = 50
batch_size = 10
buffer_size = 100


# Construct the dataset op
EPOCH_NUMBER = FLAGS.epoch_number
if EPOCH_NUMBER <= 0:
EPOCH_NUMBER = None

BATCH_CAPACITY = FLAGS.batch_thread_number * FLAGS.batch_size + FLAGS.min_after_dequeue
train_filename_list = [FLAGS.train_file]
validation_filename_list = [FLAGS.validation_file]

train_filename_placeholder = tf.placeholder(tf.string, shape=[None])
validation_filename_placeholder = tf.placeholder(tf.string, shape=[None])

train_dataset = tf.data.TFRecordDataset(train_filename_placeholder)
validation_dataset = tf.data.TFRecordDataset(validation_filename_placeholder)

train_dataset = train_dataset.map(_parse_tfrecords_function).repeat(epoch_number).batch(batch_size).shuffle(buffer_size=buffer_size)
validation_dataset = validation_dataset.map(_parse_tfrecords_function).repeat(epoch_number).batch(batch_size).shuffle(buffer_size=buffer_size)

train_dataset_iterator = train_dataset.make_initializable_iterator()
validation_dataset_iterator = validation_dataset.make_initializable_iterator()

batch_features_op, batch_label_op = train_dataset_iterator.get_next()
validate_batch_features, validate_batch_labels = validation_dataset_iterator.get_next()

batch_label_op = tf.cast(batch_label_op, tf.int32)
validate_batch_labels = tf.cast(validate_batch_labels, tf.int32)
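Editor's note: two details are worth flagging in this hunk. First, the uppercase `EPOCH_NUMBER` computed from `FLAGS.epoch_number` (and `BATCH_CAPACITY`) is never used by the dataset path; the hardcoded lowercase `epoch_number`, `batch_size`, and `buffer_size` are what feed the pipeline. Second, chaining `repeat().batch().shuffle()` shuffles whole batches rather than individual examples. A sketch of the more conventional ordering, reusing the same parse function (the validation pipeline would mirror it):

```python
train_dataset = (tf.data.TFRecordDataset(train_filename_placeholder)
                 .map(_parse_tfrecords_function)
                 .shuffle(buffer_size=buffer_size)  # shuffle examples, not batches
                 .repeat(epoch_number)              # None repeats indefinitely
                 .batch(batch_size))
```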


"""
if FLAGS.train_file_format == "tfrecords":
read_and_decode_function = read_and_decode_tfrecords
pass
#read_and_decode_function = read_and_decode_tfrecords
elif FLAGS.train_file_format == "csv":
read_and_decode_function = read_and_decode_csv

train_filename_queue = tf.train.string_input_producer(
tf.train.match_filenames_once(FLAGS.train_file), num_epochs=EPOCH_NUMBER)
train_label, train_features = read_and_decode_function(train_filename_queue)
batch_labels, batch_features = tf.train.shuffle_batch(
[train_label, train_features],
batch_size=FLAGS.batch_size,
num_threads=FLAGS.batch_thread_number,
capacity=BATCH_CAPACITY,
min_after_dequeue=FLAGS.min_after_dequeue)

validate_filename_queue = tf.train.string_input_producer(
tf.train.match_filenames_once(FLAGS.validate_file),
num_epochs=EPOCH_NUMBER)
validate_label, validate_features = read_and_decode_function(
validate_filename_queue)
validate_batch_labels, validate_batch_features = tf.train.shuffle_batch(
[validate_label, validate_features],
batch_size=FLAGS.validate_batch_size,
num_threads=FLAGS.batch_thread_number,
capacity=BATCH_CAPACITY,
min_after_dequeue=FLAGS.min_after_dequeue)
pass
#read_and_decode_function = read_and_decode_csv
"""

# Define the model
input_units = FLAGS.feature_size
output_units = FLAGS.label_size

logging.info("Use the model: {}, model network: {}".format(
FLAGS.model, FLAGS.dnn_struct))
logits = inference(batch_features, input_units, output_units, True)

#logits = inference(batch_features, input_units, output_units, True)
logits = inference(batch_features_op, input_units, output_units, True)


if FLAGS.scenario == "classification":
batch_labels = tf.to_int64(batch_labels)
batch_label_op = tf.to_int64(batch_label_op)
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits, labels=batch_labels)
#logits=logits, labels=batch_labels)
logits=logits, labels=batch_label_op)
loss = tf.reduce_mean(cross_entropy, name="loss")
elif FLAGS.scenario == "regression":
msl = tf.square(logits - batch_labels, name="msl")
msl = tf.square(logits - batch_label_op, name="msl")
loss = tf.reduce_mean(msl, name="loss")

global_step = tf.Variable(0, name="global_step", trainable=False)
@@ -370,19 +388,19 @@ def main():

# Avoid error when not using acc and auc op
if FLAGS.scenario == "regression":
batch_labels = tf.to_int64(batch_labels)
batch_labels = tf.to_int64(batch_label_op)

# Define accuracy op for train data
train_accuracy_logits = inference(batch_features, input_units, output_units,
train_accuracy_logits = inference(batch_features_op, input_units, output_units,
False)
train_softmax = tf.nn.softmax(train_accuracy_logits)
train_correct_prediction = tf.equal(
tf.argmax(train_softmax, 1), batch_labels)
tf.argmax(train_softmax, 1), batch_label_op)
train_accuracy = tf.reduce_mean(
tf.cast(train_correct_prediction, tf.float32))

# Define auc op for train data
batch_labels = tf.cast(batch_labels, tf.int32)
batch_labels = tf.cast(batch_label_op, tf.int32)
sparse_labels = tf.reshape(batch_labels, [-1, 1])
derived_size = tf.shape(batch_labels)[0]
indices = tf.reshape(tf.range(0, derived_size, 1), [-1, 1])
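Editor's note: the hunk collapses before the one-hot label tensor is assembled from `indices` and `sparse_labels`. A sketch of the classic TF 1.x pattern this setup points to, on the assumption that the collapsed code builds dense one-hot labels and feeds an AUC metric:

```python
# Scatter 1.0 at (row, label) to build dense one-hot labels of
# shape [derived_size, FLAGS.label_size].
concated = tf.concat(axis=1, values=[indices, sparse_labels])
outshape = tf.stack([derived_size, FLAGS.label_size])
new_batch_labels = tf.sparse_to_dense(concated, outshape, 1.0, 0.0)
_, train_auc_op = tf.metrics.auc(new_batch_labels, train_softmax)
```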
@@ -453,15 +471,18 @@ def main():
writer = tf.summary.FileWriter(FLAGS.output_path, sess.graph)
sess.run(init_op)


sess.run(train_dataset_iterator.initializer, feed_dict={train_filename_placeholder: train_filename_list})
sess.run(validation_dataset_iterator.initializer, feed_dict={validation_filename_placeholder: validation_filename_list})

if FLAGS.mode == "train":
# Restore session and start queue runner
restore_from_checkpoint(sess, saver, LATEST_CHECKPOINT)
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord, sess=sess)

start_time = datetime.datetime.now()

try:
while not coord.should_stop():
while True:
if FLAGS.enable_benchmark:
sess.run(train_op)
else:
@@ -494,19 +515,18 @@ def main():
except tf.errors.OutOfRangeError:
if FLAGS.enable_benchmark:
print("Finish training for benchmark")
exit(0)
sys.exit(0)
else:
# Export the model after training
print("Do not export the model yet")
sys.exit(0)


finally:
coord.request_stop()
coord.join(threads)
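Editor's note: with the dataset API, `tf.errors.OutOfRangeError` is raised by the iterator once `repeat()` is exhausted, not by queue runners, so the coordinator and threads kept here are vestigial. A minimal sketch of the dataset-only loop shape, assuming no queue-based ops remain in the graph:

```python
sess.run(train_dataset_iterator.initializer,
         feed_dict={train_filename_placeholder: train_filename_list})
try:
  while True:
    sess.run(train_op)
except tf.errors.OutOfRangeError:
  print("Finished training: the repeat() count was exhausted")
```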

elif FLAGS.mode == "savedmodel":
if restore_from_checkpoint(sess, saver, LATEST_CHECKPOINT) == False:
logging.error("No checkpoint for exporting model, exit now")
exit(1)
sys.exit(1)

graph_file_name = "graph.pb"
logging.info("Export the graph to: {}".format(FLAGS.model_path))
@@ -538,7 +558,7 @@ def main():
elif FLAGS.mode == "inference":
if restore_from_checkpoint(sess, saver, LATEST_CHECKPOINT) == False:
logging.error("No checkpoint for inferencing, exit now")
exit(1)
sys.exit(1)

# Load inference test data
inference_result_file_name = FLAGS.inference_result_file
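Editor's note: the inference branch is also collapsed. In this repository family the pattern is to load the CSV with NumPy and run a softmax built on a separate placeholder; a sketch under that assumption (`inference_softmax` and `inference_features_placeholder` are hypothetical names):

```python
# np is already imported at the top of the file.
inference_data = np.genfromtxt(FLAGS.inference_data_file, delimiter=",")
inference_features = inference_data[:, 0:FLAGS.feature_size]

# Hypothetical: a placeholder-fed softmax defined in the collapsed code.
prediction_softmax = sess.run(
    inference_softmax,
    feed_dict={inference_features_placeholder: inference_features})
predictions = np.argmax(prediction_softmax, axis=1)
np.savetxt(inference_result_file_name, predictions, fmt="%d")
```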
@@ -578,3 +598,4 @@ def main():

if __name__ == "__main__":
main()
