
rohtash0211/tensornets


TensorNets

High level network definitions with pre-trained weights in TensorFlow (tested with >= 1.1.0).

Guiding principles

  • Applicability. Many people already have their own ML workflows and want to plug a new model into them. TensorNets fits in easily because it is designed as a set of simple functional interfaces without custom classes.
  • Manageability. Models are written in tf.contrib.layers, which is lightweight like PyTorch and Keras and gives easy access to every weight and end-point. It is also easy to deploy and extend the collection of pre-processing functions and pre-trained weights.
  • Readability. Recent TensorFlow APIs allow more factoring and less indentation. For example, all the Inception variants take about 500 lines of code in TensorNets, versus 2000+ lines in the official TensorFlow models.
  • Reproducibility. You can always reproduce the original results with simple APIs, including feature extraction. You also don't need to worry about your TensorFlow version, because compatibility with various TensorFlow releases has been checked with Travis.

Installation

You can install TensorNets from PyPI (pip install tensornets) or directly from GitHub (pip install git+https://github.com/taehoonlee/tensornets.git).

A quick example

Each network (see full list) is not a custom class but a function that takes and returns tf.Tensor as its input and output. Here is an example of ResNet50:

import tensorflow as tf
import tensornets as nets

inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
model = nets.ResNet50(inputs)

assert isinstance(model, tf.Tensor)

You can load an example image with utils.load_img, which returns an np.ndarray in NHWC format:

img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)
assert img.shape == (1, 224, 224, 3)
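The text above does not say how target_size and crop_size interact. The sketch below shows the geometry under one assumption: load_img scales the shorter side to target_size, keeps the aspect ratio, and then cuts a centered crop_size x crop_size window. `resize_then_center_crop` is a hypothetical helper, not part of TensorNets.

```python
def resize_then_center_crop(height, width, target_size, crop_size):
    """Return the resized (h, w) and the (top, left) corner of a centered crop.

    Assumption: the shorter side becomes target_size, the aspect ratio is kept,
    and a crop_size x crop_size window is taken from the middle.
    """
    shorter = min(height, width)
    new_h = height * target_size // shorter
    new_w = width * target_size // shorter
    if crop_size > min(new_h, new_w):
        raise ValueError("crop_size is larger than the resized image")
    top = (new_h - crop_size) // 2
    left = (new_w - crop_size) // 2
    return (new_h, new_w), (top, left)


# A 480x640 (H x W) photo is resized to 256x341; the crop starts at row 16, column 58.
print(resize_then_center_crop(480, 640, target_size=256, crop_size=224))
```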

Once your network is created, you can run it with regular TensorFlow APIs 😊 because all networks in TensorNets return tf.Tensor. Loading pre-trained weights and pre-processing inputs only takes calls to pretrained() and preprocess(), and these calls reproduce the original results:

with tf.Session() as sess:
    img = model.preprocess(img)  # equivalent to img = nets.preprocess(model, img)
    sess.run(model.pretrained())  # equivalent to nets.pretrained(model)
    preds = sess.run(model, {inputs: img})
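The README does not say what preprocess() computes. ImageNet models are usually trained with one of a few standard input normalizations. The sketch below contrasts three common conventions on a single RGB pixel. It is an illustration only: which convention a particular TensorNets model uses is defined by the library's preprocess(), not by these hypothetical functions.

```python
# Three common ImageNet preprocessing conventions, applied to one RGB pixel in [0, 255].
# Illustration only: the mapping from TensorNets models to conventions is not shown here.

CAFFE_BGR_MEAN = (103.939, 116.779, 123.68)  # per-channel means in B, G, R order
TORCH_MEAN = (0.485, 0.456, 0.406)           # per-channel means in R, G, B order
TORCH_STD = (0.229, 0.224, 0.225)            # per-channel stds in R, G, B order


def caffe_style(rgb):
    """Reorder RGB to BGR and subtract the channel means (no scaling)."""
    r, g, b = rgb
    return tuple(c - m for c, m in zip((b, g, r), CAFFE_BGR_MEAN))


def tf_style(rgb):
    """Scale [0, 255] linearly to [-1, 1]."""
    return tuple(c / 127.5 - 1.0 for c in rgb)


def torch_style(rgb):
    """Scale to [0, 1], then standardize each channel with its mean and std."""
    return tuple((c / 255.0 - m) / s for c, m, s in zip(rgb, TORCH_MEAN, TORCH_STD))


pixel = (255, 128, 0)
print(caffe_style(pixel))
print(tf_style(pixel))
print(torch_style(pixel))
```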

You can see the most probable classes:

print(nets.utils.decode_predictions(preds, top=2)[0])
[(u'n02124075', u'Egyptian_cat', 0.28067636), (u'n02127052', u'lynx', 0.16826575)]
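decode_predictions turns each image's class scores into the (WordNet ID, class name, score) triples printed above. The stdlib sketch below does the same for a single image. It uses a hypothetical three-entry label table in place of the real 1000-class ImageNet index, and `decode_top_k` is an illustrative name, not a TensorNets function.

```python
def decode_top_k(probs, labels, top=2):
    """Pair each score with its (wnid, name) label and keep the k best.

    probs  : class scores for one image
    labels : (wnid, name) tuples, indexed like probs
    """
    if len(probs) != len(labels):
        raise ValueError("probs and labels must have the same length")
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(labels[i][0], labels[i][1], probs[i]) for i in order[:top]]


# Hypothetical 3-class label table and scores.
labels = [("n01440764", "tench"), ("n02124075", "Egyptian_cat"), ("n02127052", "lynx")]
probs = [0.05, 0.28, 0.17]
print(decode_top_k(probs, labels, top=2))
# [('n02124075', 'Egyptian_cat', 0.28), ('n02127052', 'lynx', 0.17)]
```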

You can also easily obtain values of intermediate layers with get_middles() and get_outputs():

with tf.Session() as sess:
    # img was already pre-processed in the previous snippet; do not apply preprocess() twice
    sess.run(model.pretrained())
    middles = sess.run(model.get_middles(), {inputs: img})
    outputs = sess.run(model.get_outputs(), {inputs: img})

model.print_middles()
assert middles[0].shape == (1, 56, 56, 256)
assert middles[-1].shape == (1, 7, 7, 2048)

model.print_outputs()
assert sum(sum((outputs[-1] - preds) ** 2)) < 1e-8
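The asserted shapes follow from the network's strides. middles[0] comes from conv2, at stride 4 (224 / 4 = 56). middles[-1] comes from conv5, at stride 32 (224 / 32 = 7). The sketch below reproduces that arithmetic with the floor-mode output-size formula and the explicit padding visible in the print_outputs() listing under Utilities: conv1/pad is 230 = 224 + 2*3, and pool1/pad is 114 = 112 + 2*1. The code assumes that conv3 to conv5 each downsample with a stride-2 1x1 convolution. A stride-2 3x3 convolution with padding 1 would give the same sizes.

```python
def conv_out(size, kernel, stride, pad=0):
    """Output size of a convolution or pooling window along one axis (floor mode)."""
    return (size + 2 * pad - kernel) // stride + 1


def resnet50_spatial_sizes(input_size=224):
    """Spatial size after each ResNet50 stage, under the assumptions stated above."""
    sizes = {}
    s = conv_out(input_size, kernel=7, stride=2, pad=3)  # conv1: 224 -> 112
    sizes["conv1"] = s
    s = conv_out(s, kernel=3, stride=2, pad=1)  # pool1: 112 -> 56
    sizes["conv2"] = s  # conv2 keeps stride 1
    for stage in ("conv3", "conv4", "conv5"):
        s = conv_out(s, kernel=1, stride=2)  # assumed stride-2 1x1 conv
        sizes[stage] = s
    return sizes


print(resnet50_spatial_sizes())
# {'conv1': 112, 'conv2': 56, 'conv3': 28, 'conv4': 14, 'conv5': 7}
```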

TensorNets enables us to deploy well-known architectures and benchmark those results faster ⚡️. For more information, you can check out the lists of utilities, examples, and architectures.

Object detection example

Each object detection model can be coupled with any network in TensorNets (see performance) and takes two arguments: a placeholder and a function acting as a stem layer. Here is an example of YOLOv2 for PASCAL VOC:

import tensorflow as tf
import tensornets as nets

inputs = tf.placeholder(tf.float32, [None, 416, 416, 3])
model = nets.YOLOv2(inputs, nets.Darknet19)

img = nets.utils.load_img('cat.png')

with tf.Session() as sess:
    sess.run(model.pretrained())
    preds = sess.run(model, {inputs: model.preprocess(img)})
    boxes = model.get_boxes(preds, img.shape[1:3])

Like other models, a detection model also returns tf.Tensor as its output. You can see the bounding box predictions (x1, y1, x2, y2, score) by using model.get_boxes(model_output, original_img_shape) and visualize the results:

from tensornets.datasets import voc
print("%s: %s" % (voc.classnames[7], boxes[7][0]))  # 7 is cat

import numpy as np
import matplotlib.pyplot as plt
box = boxes[7][0]
plt.imshow(img[0].astype(np.uint8))
plt.gca().add_patch(plt.Rectangle(
    (box[0], box[1]), box[2] - box[0], box[3] - box[1],
    fill=False, edgecolor='r', linewidth=2))
plt.show()
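When several boxes are detected, it helps to keep only the confident ones and convert them to the (x, y, width, height) form that plt.Rectangle takes. The sketch below assumes only what the text above states: boxes is indexed by class id, and each entry holds rows of (x1, y1, x2, y2, score). The helper name and the 0.5 threshold are illustrative choices.

```python
def confident_boxes(boxes, class_id, min_score=0.5):
    """Return (x, y, width, height, score) tuples for one class, best first.

    Assumes boxes[class_id] is a sequence of (x1, y1, x2, y2, score) rows,
    as in boxes[7][0] above; NumPy rows unpack the same way as lists.
    """
    kept = [row for row in boxes[class_id] if row[4] >= min_score]
    kept.sort(key=lambda row: row[4], reverse=True)
    return [(x1, y1, x2 - x1, y2 - y1, score) for x1, y1, x2, y2, score in kept]


# Fake detections for the 20 VOC classes; only class 7 (cat) has boxes.
fake_boxes = [[] for _ in range(20)]
fake_boxes[7] = [(0, 0, 5, 5, 0.2), (10, 20, 110, 220, 0.9)]
print(confident_boxes(fake_boxes, 7))  # [(10, 20, 100, 200, 0.9)]
```

You can draw each returned tuple with plt.Rectangle((x, y), width, height, fill=False).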

More detection examples such as FasterRCNN on VOC2007 are here 😎. Note that:

  • APIs of detection models are slightly different:

    • YOLOv3: sess.run(model.preds, {inputs: img}),
    • YOLOv2: sess.run(model, {inputs: img}),
    • FasterRCNN: sess.run(model, {inputs: img, model.scales: scale}),
  • FasterRCNN requires roi_pooling:

    • git clone https://github.com/deepsense-io/roi-pooling && cd roi-pooling && vi roi_pooling/Makefile and edit according to here,
    • python setup.py install.
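The three detection APIs differ only in what is fetched and what is fed. A small dispatch helper can keep the calling code uniform. The sketch below is built only from the three calls listed above. `detection_run_args` and its `kind` strings are hypothetical, and the caller supplies the scale value, as in the FasterRCNN note.

```python
def detection_run_args(kind, model, inputs, img, scale=None):
    """Return (fetches, feed_dict) for sess.run, following the notes above.

    `kind` is a hypothetical tag used only by this helper; TensorNets models
    do not take such an argument.
    """
    if kind == "YOLOv3":
        return model.preds, {inputs: img}
    if kind == "YOLOv2":
        return model, {inputs: img}
    if kind == "FasterRCNN":
        if scale is None:
            raise ValueError("FasterRCNN also needs a value for model.scales")
        return model, {inputs: img, model.scales: scale}
    raise ValueError("unknown detection model kind: %r" % (kind,))


# Usage inside a session (model, inputs, img, sess as in the example above):
# fetches, feed = detection_run_args("YOLOv2", model, inputs, img)
# preds = sess.run(fetches, feed)
```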

Utilities

Besides pretrained() and preprocess(), the output tf.Tensor provides the following useful methods:

  • get_middles(): returns a list of all the representative tf.Tensor end-points,
  • get_outputs(): returns a list of all the tf.Tensor end-points,
  • get_weights(): returns a list of all the tf.Tensor weight matrices,
  • print_middles(): prints all the representative end-points,
  • print_outputs(): prints all the end-points,
  • print_weights(): prints all the weight matrices,
  • print_summary(): prints the numbers of layers, weight matrices, and parameters.
Example outputs of print methods are:
>>> model.print_middles()
Scope: resnet50
conv2/block1/out:0 (?, 56, 56, 256)
conv2/block2/out:0 (?, 56, 56, 256)
conv2/block3/out:0 (?, 56, 56, 256)
conv3/block1/out:0 (?, 28, 28, 512)
conv3/block2/out:0 (?, 28, 28, 512)
conv3/block3/out:0 (?, 28, 28, 512)
conv3/block4/out:0 (?, 28, 28, 512)
conv4/block1/out:0 (?, 14, 14, 1024)
...

>>> model.print_outputs()
Scope: resnet50
conv1/pad:0 (?, 230, 230, 3)
conv1/conv/BiasAdd:0 (?, 112, 112, 64)
conv1/bn/batchnorm/add_1:0 (?, 112, 112, 64)
conv1/relu:0 (?, 112, 112, 64)
pool1/pad:0 (?, 114, 114, 64)
pool1/MaxPool:0 (?, 56, 56, 64)
conv2/block1/0/conv/BiasAdd:0 (?, 56, 56, 256)
conv2/block1/0/bn/batchnorm/add_1:0 (?, 56, 56, 256)
conv2/block1/1/conv/BiasAdd:0 (?, 56, 56, 64)
conv2/block1/1/bn/batchnorm/add_1:0 (?, 56, 56, 64)
conv2/block1/1/relu:0 (?, 56, 56, 64)
...

>>> model.print_weights()
Scope: resnet50
conv1/conv/weights:0 (7, 7, 3, 64)
conv1/conv/biases:0 (64,)
conv1/bn/beta:0 (64,)
conv1/bn/gamma:0 (64,)
conv1/bn/moving_mean:0 (64,)
conv1/bn/moving_variance:0 (64,)
conv2/block1/0/conv/weights:0 (1, 1, 64, 256)
conv2/block1/0/conv/biases:0 (256,)
conv2/block1/0/bn/beta:0 (256,)
conv2/block1/0/bn/gamma:0 (256,)
...

>>> model.print_summary()
Scope: resnet50
Total layers: 54
Total weights: 320
Total parameters: 25,636,712
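Total parameters is the sum, over every weight tensor, of the product of its dimensions. The sketch below applies that rule to the conv1 shapes printed by print_weights(). It assumes that batch-norm moving statistics count toward the total.

```python
from functools import reduce
import operator


def count_params(shapes):
    """Sum of the element counts of all weight tensors with the given shapes."""
    return sum(reduce(operator.mul, shape, 1) for shape in shapes)


# conv1 weights as listed by print_weights() above.
conv1 = [
    (7, 7, 3, 64),  # conv1/conv/weights
    (64,),          # conv1/conv/biases
    (64,),          # conv1/bn/beta
    (64,),          # conv1/bn/gamma
    (64,),          # conv1/bn/moving_mean
    (64,),          # conv1/bn/moving_variance
]
print(count_params(conv1))  # 9728 = 7*7*3*64 + 5*64
```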

Examples

  • Comparison of different networks:
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
models = [
    nets.MobileNet75(inputs),
    nets.MobileNet100(inputs),
    nets.SqueezeNet(inputs),
]

img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)
imgs = nets.preprocess(models, img)  # one pre-processed copy per model

with tf.Session() as sess:
    nets.pretrained(models)
    for (model, img) in zip(models, imgs):
        preds = sess.run(model, {inputs: img})
        print(nets.utils.decode_predictions(preds, top=2)[0])
  • Transfer learning:
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
outputs = tf.placeholder(tf.float32, [None, 50])
model = nets.DenseNet169(inputs, is_training=True, classes=50)

loss = tf.losses.softmax_cross_entropy(outputs, model)
train = tf.train.AdamOptimizer(learning_rate=1e-5).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # Adam's own variables need initializing too
    nets.pretrained(model)  # then load the pre-trained network weights
    for (x, y) in your_NumPy_data:  # the NHWC and one-hot format
        sess.run(train, {inputs: x, outputs: y})
  • Using multi-GPU:
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
models = []

with tf.device('gpu:0'):
    models.append(nets.ResNeXt50(inputs))

with tf.device('gpu:1'):
    models.append(nets.DenseNet201(inputs))

from tensornets.preprocess import fb_preprocess
img = nets.utils.load_img('cat.png', target_size=256, crop_size=224)
img = fb_preprocess(img)

with tf.Session() as sess:
    nets.pretrained(models)
    preds = sess.run(models, {inputs: img})
    for pred in preds:
        print(nets.utils.decode_predictions(pred, top=2)[0])
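The transfer-learning example leaves `your_NumPy_data` to the reader. Any iterable of (images, labels) batches in NHWC and one-hot format works. Below is a hypothetical stdlib generator that builds such batches from integer labels. `minibatches`, `one_hot`, `train_images`, and `train_labels` are illustrative names, not TensorNets APIs. NumPy arrays are the usual choice for feeding, though TF1's feed_dict also accepts nested lists.

```python
import random


def one_hot(label, num_classes):
    """Return a one-hot row (list of floats) for an integer class label."""
    row = [0.0] * num_classes
    row[label] = 1.0
    return row


def minibatches(images, labels, batch_size, num_classes=50, seed=0):
    """Yield shuffled (x, y) batches: x holds NHWC images, y one-hot label rows.

    images and labels are parallel sequences; the last batch may be smaller.
    """
    if len(images) != len(labels):
        raise ValueError("images and labels must have the same length")
    order = list(range(len(images)))
    random.Random(seed).shuffle(order)  # fixed seed for a reproducible order
    for start in range(0, len(order), batch_size):
        idx = order[start:start + batch_size]
        yield [images[i] for i in idx], [one_hot(labels[i], num_classes) for i in idx]


# Hypothetical usage in place of `your_NumPy_data` (one epoch):
# for (x, y) in minibatches(train_images, train_labels, batch_size=32):
#     sess.run(train, {inputs: x, outputs: y})
```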

Performance

Image classification

  • The top-k errors were obtained with TensorNets on the ImageNet validation set and may differ slightly from the original ones. The crop size was 224x224 for all models except NASNetAlarge (331x331), Inception3 (299x299), Inception4 (299x299), InceptionResNet2 (299x299), and ResNet50-152v2 (299x299).
    • Top-1: single center crop, top-1 error (%)
    • Top-5: single center crop, top-5 error (%)
    • 10-5: ten crops (1 center + 4 corners, plus their mirrored versions), top-5 error (%)
    • Size: number of parameters, rounded (with fully-connected layers)
    • Stem: number of parameters, rounded (without fully-connected layers)
  • The computation times were measured on an NVIDIA Tesla P100 (3584 cores, 16 GB global memory) with cuDNN 6.0 and CUDA 8.0.
    • Speed: milliseconds to run inference on 100 images
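The error columns can be reproduced from per-image scores. Below is a minimal stdlib sketch of the top-k error and of the ten crop positions described above. The function names are illustrative. Averaging the ten crops' scores before ranking is the common practice, but this README does not say that 10-5 was computed that way.

```python
def top_k_error(scores, labels, k):
    """Percentage of samples whose true label is not among the k highest scores."""
    misses = 0
    for row, label in zip(scores, labels):
        top = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        if label not in top:
            misses += 1
    return 100.0 * misses / len(labels)


def ten_crop_boxes(height, width, crop):
    """(top, left, mirrored) for the center crop, 4 corner crops, and their mirrors."""
    offsets = [
        ((height - crop) // 2, (width - crop) // 2),  # center
        (0, 0),                                       # top-left
        (0, width - crop),                            # top-right
        (height - crop, 0),                           # bottom-left
        (height - crop, width - crop),                # bottom-right
    ]
    return [(t, l, False) for t, l in offsets] + [(t, l, True) for t, l in offsets]


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]  # two samples, three classes
labels = [2, 1]
print(top_k_error(scores, labels, k=1))  # 100.0
print(top_k_error(scores, labels, k=2))  # 0.0
print(len(ten_crop_boxes(256, 256, 224)))  # 10
```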
Model Top-1 Top-5 10-5 Size Stem Speed References
ResNet50 25.126 7.982 6.842 25.6M 23.6M 195.4 [paper] [tf-slim] [torch-fb] [caffe] [keras]
ResNet101 23.580 7.214 6.092 44.7M 42.7M 311.7 [paper] [tf-slim] [torch-fb] [caffe]
ResNet152 23.396 6.882 5.908 60.4M 58.4M 439.1 [paper] [tf-slim] [torch-fb] [caffe]
ResNet50v2 24.040 6.966 5.896 25.6M 23.6M 209.7 [paper] [tf-slim] [torch-fb]
ResNet101v2 22.766 6.184 5.158 44.7M 42.6M 326.2 [paper] [tf-slim] [torch-fb]
ResNet152v2 21.968 5.838 4.900 60.4M 58.3M 455.2 [paper] [tf-slim] [torch-fb]
ResNet200v2 21.714 5.848 4.830 64.9M 62.9M 618.3 [paper] [tf-slim] [torch-fb]
ResNeXt50c32 22.260 6.190 5.410 25.1M 23.0M 267.4 [paper] [torch-fb]
ResNeXt101c32 21.270 5.706 4.842 44.3M 42.3M 427.9 [paper] [torch-fb]
ResNeXt101c64 20.506 5.408 4.564 83.7M 81.6M 877.8 [paper] [torch-fb]
WideResNet50 21.982 6.066 5.116 69.0M 66.9M 358.1 [paper] [torch]
Inception1 33.160 12.324 10.246 7.0M 6.0M 165.1 [paper] [tf-slim] [caffe-zoo]
Inception2 25.320 7.844 6.414 11.2M 10.2M 134.3 [paper] [tf-slim]
Inception3 22.054 6.242 5.000 23.9M 21.8M 314.6 [paper] [tf-slim] [keras]
Inception4 19.880 5.022 4.206 42.7M 41.2M 582.1 [paper] [tf-slim]
InceptionResNet2 19.744 4.748 3.962 55.9M 54.3M 656.8 [paper] [tf-slim]
NASNetAlarge 17.502 3.996 3.412 93.5M 89.5M 2081 [paper] [tf-slim]
NASNetAmobile 25.634 8.146 6.758 7.7M 6.7M 165.8 [paper] [tf-slim]
PNASNetlarge 17.366 3.950 3.358 86.2M 81.9M 1978 [paper] [tf-slim]
VGG16 28.732 9.950 8.834 138.4M 14.7M 348.4 [paper] [keras]
VGG19 28.744 10.012 8.774 143.7M 20.0M 399.8 [paper] [keras]
DenseNet121 25.028 7.742 6.522 8.1M 7.0M 202.9 [paper] [torch]
DenseNet169 23.824 6.824 5.860 14.3M 12.6M 219.1 [paper] [torch]
DenseNet201 22.680 6.380 5.466 20.2M 18.3M 272.0 [paper] [torch]
MobileNet25 48.418 24.208 21.196 0.5M 0.2M 34.46 [paper] [tf-slim]
MobileNet50 35.708 14.376 12.180 1.3M 0.8M 52.46 [paper] [tf-slim]
MobileNet75 31.588 11.758 9.878 2.6M 1.8M 70.11 [paper] [tf-slim]
MobileNet100 29.576 10.496 8.774 4.3M 3.2M 83.41 [paper] [tf-slim]
MobileNet35v2 39.914 17.568 15.422 1.7M 0.4M 57.04 [paper] [tf-slim]
MobileNet50v2 34.806 13.938 11.976 2.0M 0.7M 64.35 [paper] [tf-slim]
MobileNet75v2 30.468 10.824 9.188 2.7M 1.4M 88.68 [paper] [tf-slim]
MobileNet100v2 28.664 9.858 8.322 3.5M 2.3M 93.82 [paper] [tf-slim]
MobileNet130v2 25.320 7.878 6.728 5.4M 3.8M 130.4 [paper] [tf-slim]
MobileNet140v2 24.770 7.578 6.518 6.2M 4.4M 132.9 [paper] [tf-slim]
SqueezeNet 45.566 21.960 18.578 1.2M 0.7M 71.43 [paper] [caffe]

Object detection

  • The object detection models can be coupled with any network but mAPs could be measured only for the models with pre-trained weights. Note that:
    • YOLOv3VOC was trained by taehoonlee with this recipe modified as max_batches=70000, steps=40000,60000,
    • YOLOv2VOC is equivalent to YOLOv2(inputs, Darknet19),
    • TinyYOLOv2VOC: TinyYOLOv2(inputs, TinyDarknet19),
    • FasterRCNN_ZF_VOC: FasterRCNN(inputs, ZF),
    • FasterRCNN_VGG16_VOC: FasterRCNN(inputs, VGG16, stem_out='conv5/3').
  • The mAPs were obtained with TensorNets and may differ slightly from the original ones. The test input sizes were those reported as best in the papers:
    • YOLOv3, YOLOv2: 416x416
    • FasterRCNN: min_shorter_side=600, max_longer_side=1000
  • The computation times were measured on an NVIDIA Tesla P100 (3584 cores, 16 GB global memory) with cuDNN 6.0 and CUDA 8.0.
    • Size: number of parameters, rounded
    • Speed: milliseconds for network inference only, on a single 416x416 or 608x608 image
    • FPS: 1000 / speed
PASCAL VOC2007 test mAP Size Speed FPS References
YOLOv3VOC (416) 0.7423 62M 24.09 41.51 [paper] [darknet] [darkflow]
YOLOv2VOC (416) 0.7320 51M 14.75 67.80 [paper] [darknet] [darkflow]
TinyYOLOv2VOC (416) 0.5303 16M 6.534 153.0 [paper] [darknet] [darkflow]
FasterRCNN_ZF_VOC 0.4466 59M 241.4 4.143 [paper] [caffe] [roi-pooling]
FasterRCNN_VGG16_VOC 0.6872 137M 300.7 3.325 [paper] [caffe] [roi-pooling]
MS COCO val2014 mAP Size Speed FPS References
YOLOv3COCO (608) 0.6016 62M 60.66 16.49 [paper] [darknet] [darkflow]
YOLOv3COCO (416) 0.6028 62M 40.23 24.85 [paper] [darknet] [darkflow]
YOLOv2COCO (608) 0.5189 51M 45.88 21.80 [paper] [darknet] [darkflow]
YOLOv2COCO (416) 0.4922 51M 21.66 46.17 [paper] [darknet] [darkflow]
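The FPS column is the reciprocal of the per-image inference time. A tiny sketch of the conversion stated above, FPS = 1000 / speed (in milliseconds). By this rule, 241.4 ms is about 4.14 FPS and 300.7 ms is about 3.33 FPS.

```python
def fps_from_ms(speed_ms):
    """Frames per second for a per-image inference time given in milliseconds."""
    return 1000.0 / speed_ms


for name, ms in [("YOLOv3VOC (416)", 24.09), ("FasterRCNN_ZF_VOC", 241.4)]:
    print("%s: %.2f FPS" % (name, fps_from_ms(ms)))
```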

News 📰

Future work 🔥
