A training framework for running experiments in computer vision that's based on TensorFlow
This project provides a simple framework to experiment with and develop machine learning models.
More specifically, at Verizon Media we have used this code to perform research on the topic of "visual search" and "metric learning" within the e-commerce domain. With this repository, you can train image retrieval models to find relevant products from a catalog of thousands of products given a single query image.
We have presented this work in a poster session at the BayLearn Symposium on Oct 11, 2018. To learn more about our work, please read our technical report "Learning embeddings for Product Visual Search with Triplet Loss and Online Sampling".
The training code in this repository aims to provide a flexible framework for experimenting with low-level neural network details. While high-level APIs such as tf.estimator and tf.keras provide a quick and simple way of training standard convolutional neural networks, they abstract away many details of the training loop that could be interesting to examine within a research context.
The goal of this code base is to give researchers finer-grained control over these details while still providing "just enough" infrastructure to track, configure, and reproduce experiments in a collaborative setting.
A few benefits of using this code for training:
- Simple model definition API - All models can be created by defining a `model.config` dictionary object and the following two functions: `model.datasets()` and `model.build(inputs, train_phase)`.
- Debug mode - Use the `--debug` flag when calling train.py to jump into an IPython session immediately after model setup has completed.
- Resumable models - You can interrupt and restart the training job without worry. The trainer will pick up from the last saved checkpoint. If you want to start fresh, use the `--fresh` flag to delete the existing checkpoint and start from scratch.
- Checkpoint on exit - If your job exits for any reason (including if you manually interrupt training), the model will be saved to a checkpoint before exiting.
- Config serialization - Every training run saves the `config` hyper-parameter dictionary both as human-readable text and in a machine-readable format for later reference.
- Transparent/customizable training loop - The training loop is kept simple so that it is easy to read and follow. You can customize it by copying `trainers/default.py` into a new file.
- Hyperparameter tuning - Automated hyper-parameter tuning is possible because `config` is nothing more than a Python dictionary that may be manipulated at runtime.
- Multi-GPU training - Train on multiple GPUs at a time.
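
As a rough illustration of the model definition API listed above, a model module might look like the following. This is a minimal sketch, not code from the repository: the flat dotted-key layout of `config`, the return shapes of `datasets()` and `build()`, and the `with_overrides` helper are all assumptions, and plain Python lists stand in for TensorFlow tensors so that the example runs on its own.

```python
import copy

# Hypothetical hyper-parameter dictionary. The key names mirror settings
# documented in this README; the values are illustrative only.
config = {
    "trainer.experiment_name": "exp-xx-demo-001",
    "model.batch_size": 4,
    "model.num_steps": 100,
    "model.lr_schedule.method": "constant",
    "model.lr_schedule.constant.value": 0.01,
}


def datasets():
    """Return train/validation data (assumed shape: a dict of splits)."""
    batch = [[0.0, 1.0]] * config["model.batch_size"]
    return {"train": batch, "validation": batch}


def build(inputs, train_phase):
    """Map a batch of inputs to outputs (assumed: a dict holding a loss)."""
    # Stand-in "model": sum of squares averaged over the batch.
    loss = sum(x * x for row in inputs for x in row) / float(len(inputs))
    return {"loss": loss, "train_phase": train_phase}


def with_overrides(base, overrides):
    """Hypothetical helper: copy a config and replace selected keys."""
    new = copy.deepcopy(base)
    new.update(overrides)
    return new


# Because the config is an ordinary dictionary, a sweep is a loop over copies.
sweep = [
    with_overrides(config, {"model.lr_schedule.constant.value": lr})
    for lr in (0.1, 0.01, 0.001)
]
outputs = build(datasets()["train"], train_phase=True)
```

Copying the dictionary for each run keeps the base configuration untouched, which makes it easier to compare the serialized configs of different runs afterwards.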
This code currently runs in Python 2.7.11, so make sure you have this version running. We have plans to migrate to Python 3 in the future.
Before using this library, you will need to install a few prerequisites via pip:
pip install numpy tensorflow tensorflow-hub
Then, simply clone this repo.
git clone https://github.com/yahoo/prototrain && cd prototrain
You can start a new experiment by calling the following code from the command line.
export CUDA_VISIBLE_DEVICES=1,2
python train.py -m models.simple
This will train the model in the models/simple folder.
The following are configuration flags shared across all models.
dtype | setting | description |
---|---|---|
str | `trainer.experiment_base_dir` | The directory path to store all experiments |
str | `trainer.experiment_name` | The experiment name. This, combined with `experiment_base_dir`, will be where all experiment logs, checkpoints, and summaries will be saved. The suggested format for this string is `"exp-<initials>-<descriptive_name>-<run_number>"` like the following example: `"exp-hn-evergreen-pascal-voc80"` |
int | `trainer.validation_step_interval` | Determines interval of steps between feeding a batch of validation data through session |
int | `trainer.summary_save_interval` | Determines interval of steps between saving tensorboard summaries |
int | `trainer.print_interval` | Determines interval of steps between printing loss |
int | `trainer.checkpoint_save_interval` | Determines interval of steps between saving checkpoints |
int | `trainer.max_checkpoints_to_keep` | The number of checkpoints to save at any given time |
str | `trainer.random_seed` | A random seed to initialize tensorflow and numpy with |
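
To make the trainer settings above more concrete, here is a small sketch of how they could fit together. It assumes a flat dictionary with dotted keys, that the experiment directory is simply the base directory joined with the experiment name, and that an action fires whenever the step count is a positive multiple of its interval. The helpers `experiment_dir` and `is_due` are hypothetical and not part of the repository.

```python
import os

# Illustrative values only; the key names come from the table above.
trainer_config = {
    "trainer.experiment_base_dir": os.path.join("experiments", "visual-search"),
    "trainer.experiment_name": "exp-hn-evergreen-pascal-voc80",
    "trainer.validation_step_interval": 500,
    "trainer.summary_save_interval": 100,
    "trainer.print_interval": 10,
    "trainer.checkpoint_save_interval": 1000,
    "trainer.max_checkpoints_to_keep": 5,
}


def experiment_dir(cfg):
    """Join the base directory and experiment name (assumed behaviour)."""
    return os.path.join(cfg["trainer.experiment_base_dir"],
                        cfg["trainer.experiment_name"])


def is_due(step, interval):
    """Return True when `step` is a positive multiple of `interval` (assumed)."""
    return step > 0 and step % interval == 0


# Which interval-driven actions would fire at step 1000 under these values?
due_at_1000 = sorted(
    key for key, value in trainer_config.items()
    if key.endswith("_interval") and is_due(1000, value)
)
```
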

dtype | setting | description |
---|---|---|
str | `model.weights_load_path` | A path to a pretrained checkpoint that you'd like to load for fine-tuning |
int | `model.batch_size` | The full batch size to use for training |
int | `model.num_steps` | The number of steps to run training |
func | `model.trainable_variables.filter_fn` | A function of the form `fn(v) -> Bool` that returns True if variable `v` should be updated via gradient descent during training |
str | `model.trainable_variables.filter_regex` | A regex that returns a match if a variable should be updated during training |
func | `model.weight_decay.filter_fn` | A function of the form `fn(v) -> Bool` that returns True if variable `v` should be considered when computing the weight decay loss |
str | `model.weight_decay.filter_regex` | A regex that returns a match if a variable should be considered when computing the weight decay loss |
bool | `model.clip_gradients` | Turn on gradient clipping during training (default: False) |
float | `model.clip_gradients.minval` | Minimum gradient value (default: -1.0) |
float | `model.clip_gradients.maxval` | Maximum gradient value (default: 1.0) |
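
The variable filters and gradient clipping settings above can be pictured with the following sketch. It makes several assumptions: variable names (plain strings here) stand in for TensorFlow variables, the regex is applied with `re.search` against the variable name, and clipping is element-wise between `minval` and `maxval`. The variable names and helper functions are hypothetical.

```python
import re

# Hypothetical variable names standing in for TensorFlow variables.
variable_names = [
    "backbone/conv1/weights",
    "backbone/conv1/biases",
    "embedding/fc/weights",
    "embedding/fc/biases",
]


def select_by_regex(names, pattern):
    """Keep names for which the regex finds a match (assumed: re.search)."""
    regex = re.compile(pattern)
    return [name for name in names if regex.search(name)]


def select_by_fn(names, fn):
    """Keep names for which fn(name) returns True."""
    return [name for name in names if fn(name)]


def clip_value(gradient, minval=-1.0, maxval=1.0):
    """Clamp a single gradient value into [minval, maxval]."""
    return max(minval, min(maxval, gradient))


# Fine-tune only the embedding head; apply weight decay to weights only.
trainable = select_by_regex(variable_names, r"^embedding/")
decayed = select_by_fn(variable_names, lambda name: name.endswith("weights"))
clipped = [clip_value(g) for g in (-3.5, 0.25, 2.0)]
```

Restricting the trainable set to a name prefix is a common way to freeze a pretrained backbone while fine-tuning a new head on top of weights loaded from `model.weights_load_path`.
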

dtype | setting | description |
---|---|---|
str | `model.optimizer.method` | Choose one of the following optimizers: `msgd` (Momentum SGD), `adam`, `rmsprop`, `sgd`. Based upon this value, additional settings apply below. |
Momentum SGD settings:
dtype | setting | description |
---|---|---|
float | `model.optimizer.msgd.momentum` | Sets momentum when using `msgd` optimizer |
bool | `model.optimizer.msgd.use_nesterov` | Uses Nesterov accelerated gradients if True |
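
For intuition about what the two Momentum SGD settings control, the sketch below applies one update step to a scalar parameter. It assumes the `msgd` option follows the update rule used by TensorFlow's momentum optimizer; whether the trainer wires these settings through in exactly this way is an assumption, and `momentum_step` is a hypothetical helper.

```python
def momentum_step(var, accum, grad, lr, momentum, use_nesterov=False):
    """Apply one momentum SGD update to a scalar and return (var, accum)."""
    # The accumulator is a decaying sum of past gradients.
    accum = momentum * accum + grad
    if use_nesterov:
        # Nesterov variant: look ahead along the updated accumulator.
        var = var - lr * (grad + momentum * accum)
    else:
        var = var - lr * accum
    return var, accum


plain_var, plain_accum = momentum_step(
    var=1.0, accum=0.0, grad=0.5, lr=0.1, momentum=0.9)
nesterov_var, nesterov_accum = momentum_step(
    var=1.0, accum=0.0, grad=0.5, lr=0.1, momentum=0.9, use_nesterov=True)
```

With identical inputs, the Nesterov variant takes a slightly larger first step because it adds the momentum-scaled accumulator on top of the raw gradient.
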
ADAM settings:
dtype | setting | description |
---|---|---|
float | `model.optimizer.adam.beta1` | ADAM beta1 |
float | `model.optimizer.adam.beta2` | ADAM beta2 |
float | `model.optimizer.adam.epsilon` | ADAM epsilon |
RMSProp settings:
dtype | setting | description |
---|---|---|
float | `model.optimizer.rmsprop.decay` | RMSProp decay rate |
float | `model.optimizer.rmsprop.momentum` | RMSProp momentum |
float | `model.optimizer.rmsprop.epsilon` | RMSProp epsilon |

dtype | setting | description |
---|---|---|
str | `model.lr_schedule.method` | One of the following: `step`, `exp`, `constant`, `custom`. Based upon this value, additional settings apply below. |
Piecewise learning rate schedule:
dtype | setting | description |
---|---|---|
list(int) | `model.lr_schedule.step.boundaries` | A list of step boundary values that determine when to set the learning rate to the next learning rate value. |
list(float) | `model.lr_schedule.step.values` | List of floating point values for the learning rate. This list should have `len(boundaries) + 1` entries, where the first entry is the starting learning rate. |
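
The relationship between `boundaries` and `values` can be sketched in plain Python as follows. This assumes the convention that the learning rate switches on the step after each boundary (so a step equal to a boundary still uses the earlier value); the actual boundary handling in the trainer may differ, and `piecewise_lr` is a hypothetical helper.

```python
import bisect


def piecewise_lr(step, boundaries, values):
    """Look up the learning rate for `step` in a piecewise-constant schedule."""
    if len(values) != len(boundaries) + 1:
        raise ValueError("values must have len(boundaries) + 1 entries")
    # bisect_left counts the boundaries strictly below `step`, which is
    # exactly the index of the learning rate segment to use.
    return values[bisect.bisect_left(boundaries, step)]


boundaries = [1000, 2000]
values = [0.1, 0.01, 0.001]
schedule = [piecewise_lr(s, boundaries, values)
            for s in (0, 1000, 1001, 2000, 5000)]
```
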
Constant learning rate schedule:
dtype | setting | description |
---|---|---|
float | `model.lr_schedule.constant.value` | A constant value to be used for the learning rate throughout the entire training job. |
Exponential decay learning rate:
dtype | setting | description |
---|---|---|
float | `model.lr_schedule.exp.value` | Initial value for learning rate |
float | `model.lr_schedule.exp.decay_rate` | A value to decay learning rate by |
float | `model.lr_schedule.exp.decay_steps` | Number of steps to run before decaying the learning rate |
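
As a worked example of how the three exponential decay settings interact, the sketch below assumes the standard continuous form `value * decay_rate ** (step / decay_steps)`. Whether the trainer decays continuously or in discrete jumps every `decay_steps` steps is not stated here, so treat this as an illustration only; `exponential_lr` is a hypothetical helper.

```python
def exponential_lr(step, value, decay_rate, decay_steps):
    """Continuously decayed learning rate (assumed form of the schedule)."""
    return value * decay_rate ** (step / float(decay_steps))


# Halve the learning rate every 1000 steps, starting from 0.1.
lr_start = exponential_lr(0, value=0.1, decay_rate=0.5, decay_steps=1000)
lr_1000 = exponential_lr(1000, value=0.1, decay_rate=0.5, decay_steps=1000)
lr_2000 = exponential_lr(2000, value=0.1, decay_rate=0.5, decay_steps=1000)
```
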
Custom learning rate schedule:
dtype | setting | description |
---|---|---|
func | `model.lr_schedule.custom.fn` | A function of the form `fn(step) -> tf.float` that returns a float value that represents the learning rate at the provided `step`. |
Please refer to the contributing.md file for information about how to get involved. We welcome issues, questions, and pull requests.
- Huy Nguyen
- Eric Dodds: eric.mcvoy.dodds@verizonmedia.com
We thank the following people for their contributions, testing, and feedback on this code:
- Jack Culpepper
- Simao Herdade
- Kofi Boakye
- Andrew Kae
- Tobias Baumgartner
- Armin Kappeler
- Pierre Garrigues
This project is licensed under the terms of the Apache 2.0 open source license. Please refer to LICENSE for the full terms.