TensorFlow template application for deep learning

Introduction

This is a generic template application for deep learning with TensorFlow.

Following are the supported features.

Usage

Generate TFRecords

If your data is in CSV format, generate TFRecords like this.

cd ./data/cancer/

./generate_csv_tfrecords.py
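Conceptually, the CSV script turns each row into a label plus a list of numeric features before serializing it. In TensorFlow this is typically done by building a `tf.train.Example` and writing it to a TFRecord file. The stdlib-only sketch below shows just the parsing step. It assumes the label sits in the last column and that every value is numeric, so check this against your own data. `parse_csv_rows` is a hypothetical helper, not part of this project.

```python
import csv
import io

# Assumption: the label is the LAST column and all other columns are numeric
# features. Verify this against your own CSV before relying on it.
def parse_csv_rows(text, label_last=True):
    """Split each CSV row into (label, features) with float features."""
    examples = []
    for row in csv.reader(io.StringIO(text)):
        if not row:
            continue  # skip blank lines
        values = [float(v) for v in row]
        if label_last:
            label, features = values[-1], values[:-1]
        else:
            label, features = values[0], values[1:]
        examples.append((int(label), features))
    return examples

# Two made-up rows: nine features followed by a 0/1 label.
sample = "5,1,1,1,2,1,3,1,1,0\n5,4,4,5,7,10,3,2,1,1\n"
for label, features in parse_csv_rows(sample):
    print(label, len(features))
```

Each `(label, features)` pair would then be serialized into one record of the TFRecords file.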

If your data is in LIBSVM format, generate TFRecords like this.

cd ./data/a8a/

./generate_libsvm_tfrecord.py
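A LIBSVM line stores a label followed by `index:value` pairs for the non-zero features only. This is why the format suits the sparse classifier. The sketch below uses a hypothetical `parse_libsvm_line` helper to split one line into the label, feature ids and feature values. It keeps the label as-is (a8a uses +1/-1) and does not show which TFRecord feature names the project's script actually writes.

```python
def parse_libsvm_line(line):
    """Parse 'label idx:val idx:val ...' into (label, ids, values)."""
    tokens = line.split()
    label = int(float(tokens[0]))  # handles "+1", "-1", "0", "1.0"
    ids, values = [], []
    for token in tokens[1:]:
        index, value = token.split(":")
        ids.append(int(index))      # position of a non-zero feature
        values.append(float(value)) # its value
    return label, ids, values

print(parse_libsvm_line("-1 3:1 11:1 14:1"))  # (-1, [3, 11, 14], [1.0, 1.0, 1.0])
```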

For large datasets, you can use Spark to generate the TFRecords. Please refer to the data directory.

Run Training

You can train with the default configuration.

./dense_classifier.py

./sparse_classifier.py

Switching models or hyperparameters only takes different TensorFlow flags.

./dense_classifier.py --batch_size 1024 --epoch_number 1000 --step_to_validate 10 --optimizer adagrad --model dnn --model_network "128 32 8"
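If `--model_network "128 32 8"` lists the hidden-layer widths of the `dnn` model, then the layer shapes follow from `--feature_size` and `--label_size`. This sketch assumes that reading and uses `argparse` as a stand-in for the program's TensorFlow flags. The defaults and the helpers `dnn_weight_shapes` and `parameter_count` are hypothetical and serve only as an illustration.

```python
import argparse

def build_parser():
    # Hypothetical stand-in for a few of the program's flags; defaults are made up.
    parser = argparse.ArgumentParser()
    parser.add_argument("--feature_size", type=int, default=9)
    parser.add_argument("--label_size", type=int, default=2)
    parser.add_argument("--model", default="dnn")
    parser.add_argument("--model_network", default="128 32 8")
    return parser

def dnn_weight_shapes(feature_size, label_size, model_network):
    """Return (inputs, outputs) for each fully connected layer."""
    hidden = [int(width) for width in model_network.split()]
    sizes = [feature_size] + hidden + [label_size]
    return list(zip(sizes[:-1], sizes[1:]))

def parameter_count(shapes):
    """Weights plus one bias per output unit, summed over all layers."""
    return sum(inputs * outputs + outputs for inputs, outputs in shapes)

args = build_parser().parse_args(["--model_network", "128 32 8"])
shapes = dnn_weight_shapes(args.feature_size, args.label_size, args.model_network)
print(shapes)                   # [(9, 128), (128, 32), (32, 8), (8, 2)]
print(parameter_count(shapes))  # 5690
```

Each layer adds `inputs x outputs` weights plus one bias per output. Widening the first hidden layer therefore grows the model fastest when `--feature_size` is large.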

To use another dataset such as iris, you do not need to modify the code. Just pass flags that specify the training and validation files, either as TFRecords or, with --input_file_format csv, as raw CSV.

./dense_classifier.py --train_file ./data/iris/iris_train.csv.tfrecords --validate_file ./data/iris/iris_test.csv.tfrecords --feature_size 4 --label_size 3  --enable_colored_log

./dense_classifier.py --train_file ./data/iris/iris_train.csv --validate_file ./data/iris/iris_test.csv --feature_size 4 --label_size 3 --input_file_format csv --enable_colored_log
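`--label_size 3` matches the three iris classes. A common way for a classifier to turn its per-class scores (logits) into a prediction is a softmax followed by an argmax. The sketch below assumes that setup and uses made-up logits. Treat it as an illustration, not as this program's exact output layer.

```python
import math

def softmax(logits):
    """Turn raw per-class scores into probabilities that sum to 1."""
    peak = max(logits)  # subtracting the max avoids overflow in exp()
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict_class(logits):
    """Return (index of the most probable class, class probabilities)."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=lambda i: probs[i])
    return best, probs

# Made-up logits for one iris sample, one score per class (--label_size 3).
label, probs = predict_class([2.0, 0.5, -1.0])
print(label)  # 0
```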

To use the CNN model, try this command.

./dense_classifier.py --train_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --validate_file ./data/lung/fa7a21165ae152b13def786e6afc3edf.dcm.csv.tfrecords --feature_size 262144 --label_size 2 --batch_size 2 --validate_batch_size 2 --epoch_number -1 --model cnn
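`--feature_size 262144` is exactly 512 x 512, which would fit a single 512x512 DICOM slice flattened into one CSV row. The sketch below assumes the CNN reads the flat vector back as a square single-channel image in row-major order. That is an assumption about the model, not something stated here, and `as_square_image` is a hypothetical helper.

```python
import math

def as_square_image(flat):
    """Reshape a flat feature vector into a square 2-D grid (row-major)."""
    side = math.isqrt(len(flat))
    if side * side != len(flat):
        raise ValueError("feature_size is not a perfect square")
    return [flat[r * side:(r + 1) * side] for r in range(side)]

print(math.isqrt(262144))  # 512
image = as_square_image(list(range(262144)))
print(len(image), len(image[0]))  # 512 512
```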

For the Boston housing dataset, train a regression model like this.

./dense_classifier.py --train_file ./data/boston_housing/train.csv.tfrecords --validate_file ./data/boston_housing/train.csv.tfrecords --feature_size 13 --label_size 1 --scenario regression  --batch_size 1 --validate_batch_size 1
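The Boston housing data has 13 features and a single continuous target (median home value in $1000s). That is why it uses `--feature_size 13 --label_size 1 --scenario regression`. Mean squared error is a typical loss for such a model. The sketch assumes that loss, which may not be what the program actually optimizes.

```python
def mean_squared_error(predictions, targets):
    """Average squared difference between predictions and true values."""
    if not targets or len(predictions) != len(targets):
        raise ValueError("need two non-empty lists of equal length")
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(targets)

# Made-up model outputs and true median prices, in $1000s.
loss = mean_squared_error([24.0, 21.5, 35.0], [24.0, 21.6, 34.7])
print(round(loss, 4))
```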

Export The Model

After training, the program exports the model automatically. You can also export it manually.

./dense_classifier.py --mode savedmodel

Validate The Model

To run inference and validate the model, use inference mode.

./dense_classifier.py --mode inference
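The output of `--mode inference` is not described here. One simple way to judge a classifier on a validation set is accuracy, the fraction of predictions that match the true labels. The hypothetical `accuracy` helper below computes it from lists of predicted and expected class ids.

```python
def accuracy(predicted, expected):
    """Fraction of examples whose predicted class equals the true label."""
    if not expected or len(predicted) != len(expected):
        raise ValueError("need two non-empty lists of equal length")
    correct = sum(1 for p, e in zip(predicted, expected) if p == e)
    return correct / len(expected)

# Hypothetical predicted and true class ids for four validation examples.
print(accuracy([1, 0, 1, 1], [1, 0, 0, 1]))  # 0.75
```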

Use TensorBoard

The program will generate TensorFlow event files automatically.

tensorboard --logdir ./tensorboard/

Then go to http://127.0.0.1:6006 in the browser.

Serving and Predicting

The exported model is compatible with TensorFlow Serving. Follow the TensorFlow Serving documentation to build tensorflow_model_server, then start it.

./tensorflow_model_server --port=9000 --model_name=dense --model_base_path=./model/
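TensorFlow Serving expects `--model_base_path` to contain numbered version subdirectories, such as `./model/1`. With its default version policy, it serves the highest version it finds. The sketch below mimics that selection using only the standard library. It is not Serving's own code, and `latest_version_dir` is a hypothetical helper.

```python
import os
import tempfile

def latest_version_dir(model_base_path):
    """Return the highest numeric subdirectory (e.g. ./model/1), or None."""
    versions = [
        int(name) for name in os.listdir(model_base_path)
        if name.isdigit() and os.path.isdir(os.path.join(model_base_path, name))
    ]
    if not versions:
        return None
    return os.path.join(model_base_path, str(max(versions)))

with tempfile.TemporaryDirectory() as base:
    for name in ("1", "2", "10", "tmp"):
        os.makedirs(os.path.join(base, name))
    print(os.path.basename(latest_version_dir(base)))  # 10
```

Versions are compared as integers, not strings, so `10` wins over `2`. Exporting a retrained model into a new, larger version number is enough for a running server with the default policy to pick it up.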

gRPC clients are provided for dense and sparse models, including the Python and Java predict clients.

./predict_client.py --host 127.0.0.1 --port 9000 --model_name dense --model_version 1

mvn compile exec:java -Dexec.mainClass="com.tobe.DensePredictClient" -Dexec.args="127.0.0.1 9000 dense 1"

Contribution

This project is widely used for different tasks with dense or sparse data.

If you want to make contributions, feel free to open an issue or pull request.