-
Notifications
You must be signed in to change notification settings - Fork 1.6k
🚀 Tenary Weight and DoReFa-Net in TensorFlow (TensorLayer) #440
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
tensorlayer/layers/binary.py
Outdated
| import tensorflow as tf | ||
| from tensorflow.python.framework import ops | ||
|
|
||
| bitW = 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you avoid using global variables to setup the parameters?
tensorlayer/layers/binary.py
Outdated
| with G.gradient_override_map({"Round": "Identity"}): | ||
| return tf.round(x * n) / n | ||
|
|
||
| def fw(x, force_quantization=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we use a more meaningful names? If this function is used only internally, we better to start the name with a underscore.
| import os, time | ||
| import tensorflow as tf | ||
| import tensorlayer as tl | ||
| from tensorlayer.layers import * |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We prefer to avoiding "import *" in new committed files.
| @@ -1,11 +1,19 @@ | |||
| # -*- coding: utf-8 -*- | |||
| from .core import * | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't use import * for new code.
fiexed some bugs of format
add bitw and bita for apis
|
@zsdonghao as we are having a large number of examples, shall we actually start organizing the examples into different folders? |
| @@ -0,0 +1,324 @@ | |||
| #! /usr/bin/python | |||
| # -*- coding: utf-8 -*- | |||
| """Reimplementation of the TensorFlow official CIFAR-10 CNN tutorials. | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can remove this comment.
| data_to_tfrecord(images=X_test, labels=y_test, filename="test.cifar10") | ||
|
|
||
| ## Example to visualize data | ||
| # img, label = read_and_decode("train.cifar10", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can remove this comment.
| @@ -0,0 +1,319 @@ | |||
| #! /usr/bin/python | |||
| # -*- coding: utf-8 -*- | |||
| """Reimplementation of the TensorFlow official CIFAR-10 CNN tutorials. | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can remove this comment.
| data_to_tfrecord(images=X_train, labels=y_train, filename="train.cifar10") | ||
| data_to_tfrecord(images=X_test, labels=y_test, filename="test.cifar10") | ||
|
|
||
| ## Example to visualize data |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can remove this comment.
| @@ -0,0 +1,318 @@ | |||
| #! /usr/bin/python | |||
| # -*- coding: utf-8 -*- | |||
| """Reimplementation of the TensorFlow official CIFAR-10 CNN tutorials. | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can remove this comment.
| data_to_tfrecord(images=X_test, labels=y_test, filename="test.cifar10") | ||
|
|
||
| ## Example to visualize data | ||
| # img, label = read_and_decode("train.cifar10", None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can remove this comment.
| Examples | ||
| --------- | ||
| >>> net = tl.layers.InputLayer(x, name='input') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the example here should be modified.
| Examples | ||
| --------- | ||
| >>> net = tl.layers.InputLayer(x, name='input') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the example here should be modified
| self.inputs = prev_layer.outputs | ||
| if act is None: | ||
| act = tf.identity | ||
| logging.info("BinaryConv2d %s: n_filter:%d filter_size:%s strides:%s pad:%s act:%s" % (self.name, n_filter, str(filter_size), str(strides), padding, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change BinaryConv2d to DorefaConv2d
|
|
||
| n_in = int(self.inputs.get_shape()[-1]) | ||
| self.n_units = n_units | ||
| logging.info("BinaryDenseLayer %s: %d %s" % (self.name, self.n_units, act.__name__)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change BinaryDenseLayer to DorefaDenseLayer
| self.inputs = prev_layer.outputs | ||
| if act is None: | ||
| act = tf.identity | ||
| logging.info("BinaryConv2d %s: n_filter:%d filter_size:%s strides:%s pad:%s act:%s" % (self.name, n_filter, str(filter_size), str(strides), padding, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change BinaryConv2d to TenaryConv2d
|
|
||
| n_in = int(self.inputs.get_shape()[-1]) | ||
| self.n_units = n_units | ||
| logging.info("BinaryDenseLayer %s: %d %s" % (self.name, self.n_units, act.__name__)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
change BinaryDenseLayer to TenaryDenseLayer
| return tf.sign(x) | ||
|
|
||
|
|
||
| def _quantize_dorefa(x, k): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need comment
| return tf.round(x * n) / n | ||
|
|
||
|
|
||
| def _quantize_weight(x, bitW, force_quantization=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need comment
| return 2 * _quantize_dorefa(x, bitW) - 1 | ||
|
|
||
|
|
||
| def _quantize_active(x, bitA): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need comment
| return _quantize_dorefa(x, bitA) | ||
|
|
||
|
|
||
| def _cabs(x): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need comment
| return tf.minimum(1.0, tf.abs(x), name='cabs') | ||
|
|
||
|
|
||
| def _compute_threshold(x): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need comment
| return threshold | ||
|
|
||
|
|
||
| def _compute_alpha(x): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need comment
| return alpha | ||
|
|
||
|
|
||
| def _tenary_opration(x): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
need comment
| # ref: https://github.com/itayhubara/BinaryNet.tf/blob/master/models/BNN_cifar10.py | ||
| with tf.variable_scope("binarynet", reuse=reuse): | ||
| net = tl.layers.InputLayer(x, name='input') | ||
| net = tl.layers.DorefaConv2d(net, 32, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn1') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How does DorefaConv2d know that 32 is for n_filter but not for bitW or bitA?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it crashed with
$ python3 example/tutorial_dorefanet_mnist_cnn.py
[TL] Load or Download MNIST > data/mnist
[TL] data/mnist/train-images-idx3-ubyte.gz
[TL] data/mnist/t10k-images-idx3-ubyte.gz
2018-03-22 21:30:07.666529: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2
[TL] InputLayer binarynet/input: (128, 28, 28, 1)
Traceback (most recent call last):
File "example/tutorial_dorefanet_mnist_cnn.py", line 48, in <module>
net_train = model(x, is_train=True, reuse=False)
File "example/tutorial_dorefanet_mnist_cnn.py", line 25, in model
net = tl.layers.DorefaConv2d(net, 32, (5, 5), (1, 1), padding='SAME', b_init=None, name='bcnn1')
File "/home/vagrant/.local/lib/python3.6/site-packages/tensorlayer/layers/binary.py", line 687, in __init__
act.__name__))
TypeError: %d format: a number is required, not tuple
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I forgot it and I have fixed the bug
|
Since our functions have long argument lists, and many of them have default values, I suggest that we always use keyword style argument passing (except for small helper functions that 1,2 or 3 arguments). |
tensorlayer/layers/binary.py
Outdated
| class DorefaConv2d(Layer): | ||
| """ | ||
| The :class:`BinaryConv2d` class is a 2D binary CNN layer, which weights are either -1 or 1 while inferencing. | ||
| """The :class:`DorefaConv2d` class is a binary fully connected layer, which weights are 'bitW' bits and the input of the previous layer |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should be convolutional layer
fix some issue and delete some comment
|
bnn is a excellent work in the compression of neuron network but it can not get a satisfied accuracy on relative large datasets, in order to solve the problem, tenary weight networks and dorefa were put forward. I add 4 apis for tensorlayer, Tenary Denselayer, TenaryConv2d, DorefaDenselayer, and DorefaConv2d . I perform 6 experiment based on mnist and cifar10,the details are in thr titorials. |
|
@XJTUWYD Hi, i suggest you to add this comment into this issue #416, more people can see it. and include your previous comment: I add two compress strategies tenary weight network and dorefa-net into tensorlayer, did two experiments to compare the accuracy of different compress strategies based on mnist and cifar-10.
|
* compress add four apis TenaryConv2d,TenaryDenseLayer,DorefaConv2d,DorefaDenselyLayer and build different tutorials for bnn,twn,dorefa based on mnist and cifar10 datasets * four apis four apis * fiexed some bugs of format fiexed some bugs of format * add bitw and bita for apis add bitw and bita for apis * Add files via upload * Add files via upload * Add files via upload * do some explain about twn and dorefa * fix some issue and delete some comment fix some issue and delete some comment * add some comment * use yapf
I add two compress strategies tenary weight network and dorefa-net into tensorlayer, did two experiments to compare the accuracy of different compress strategies based on mnist and cifar-10.
the result of the experiment is below:
related issue: #416