Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wy1iu committed Mar 5, 2018
1 parent fef27c9 commit 81f6e5b
Show file tree
Hide file tree
Showing 33 changed files with 6,356 additions and 2 deletions.
137 changes: 135 additions & 2 deletions README.md
@@ -1,2 +1,135 @@
# SphereNet
Implementation for <Deep Hyperspherical Learning> in NIPS'17.
# Deep Hyperspherical Learning

By Weiyang Liu, Yan-Ming Zhang, Xingguo Li, Zhiding Yu, Bo Dai, Tuo Zhao, Le Song

### License

SphereNet is released under the MIT License (refer to the LICENSE file for details).

### Code to be Updated
- [x] SphereNet: a neural network that learns on hyperspheres </li>
- [ ] SphereResNet: an adaptation of SphereConv to residual networks </li>
- [ ] Feature visualization on MNIST </li>

### Contents
0. [Introduction](#introduction)
0. [Citation](#citation)
0. [Requirements](#requirements)
0. [Usage](#usage)
0. [Results](#results)
0. [Notes](#notes)
0. [Third-party re-implementation](#third-party-re-implementation)
0. [Contact](#contact)


### Introduction

The repository contains an example Tensorflow implementation for SphereNets. SphereNets are introduced in the NIPS 2017 paper "[Deep Hyperspherical Learning](http://wyliu.com/papers/LiuNIPS17.pdf)" ([arXiv](https://arxiv.org/abs/1711.03189)). SphereNets are able to converge faster and more stably than its CNN counterparts, while yielding to comparable or even better classification accuracy.

Hyperspherical learning is inspired by an interesting obvervation of the 2D Fourier transform. From the image below, we could see that magnitude information is not crucial for recognizing the identity, but phase information is very important for recognition. By droping the magnitude information, SphereNets can reduce the learning space and therefore gain more convergence speed. *Hypersphereical learning provides a new framework to improve the convolutional neural networks.*

<img src="asserts/2dfourier.png" width="52%" height="52%">

The features learned by SphereNets are also very interesting. The 2D features of SphereNets learned on MNIST are more compact and have larger margin between classes. From the image below, we can see that local behavior of convolutions could lead to dramatic difference in final features, even if they are supervised by the same standard softmax loss. *Hypersphereical learning provides a new perspective to think about convolutions and deep feature learning.*

<img src="asserts/mnist_featvis.jpg" width="51%" height="51%">

Besides, the hyperspherical learning also leads to a well-performing normalization technique, SphereNorm. SphereNorm basically can be viewed as SphereConv operator in our implementation.

### Citation

If you find our work useful in your research, please consider to cite:

@inproceedings{liu2017deep,
title={Deep Hyperspherical Learning},
author={Liu, Weiyang and Zhang, Yan-Ming and Li, Xingguo and Yu, Zhiding and Dai, Bo and Zhao, Tuo and Song, Le},
booktitle={Advances in Neural Information Processing Systems},
pages={3953--3963},
year={2017}
}


### Requirements
1. `Python 2.7`
2. `TensorFlow` (Tested on version 1.01)
3. `numpy`


### Usage

#### Part 1: Setup
- Clone the repositary and download the training set.

```Shell
git clone https://github.com/wy1iu/SphereNet.git
cd SphereNet
./dataset_setup.sh
```

#### Part 2: Train Baseline/SphereNets

- To train the baseline model, please open `baseline/train_baseline.py` and assign an available GPU. The default hyperparameters are exactly the same with SphereNets.

```Shell
python baseline/train_baseline.py
```

- To train the SphereNet, please open `train_spherenet.py` and assign an available GPU.

```Shell
python train_spherenet.py
```


### Configuration
The default setting of SphereNet is Cosine SphereConv + Standard Softmax Loss. To change the type of SphereConv, please open the `spherenet.py` and change the `norm` variable.

- If `norm` is set to `none`, then the network will use original convolution and become standard CNN.
- If `norm` is set to `linear`, then the SphereNet will use linear SphereConv.
- If `norm` is set to `cosine`, then the SphereNet will use cosine SphereConv.
- If `norm` is set to `sigmoid`, then the SphereNet will use sigmoid SphereConv.
- If `norm` is set to `lr_sigmoid`, then the SphereNet will use learnable sigmoid SphereConv.

The `w_norm` variable can also be changed similarly in order to use the weight-normalized softmax loss (combined with different SphereConv). By setting `w_norm` to `none`, we will use the standard softmax loss.

There are some examples of setting these two variables provided in the `examples/` foloder.


### Results
#### Part 1: Convergence

The convergence curves for baseline CNN and several types of SphereNets are given as follows.
<img src="asserts/convergence.jpg" width="51%" height="51%">


#### Part 2: Best testing accuracy on CIFAR-10

- Baseline (standard CNN with standard softmax loss): 90.86%
- SphereNet with cosine SphereConv and standard softmax loss: 91.31%
- SphereNet with linear SphereConv and standard softmax loss: 91.65%
- SphereNet with sigmoid SphereConv and standard softmax loss: 91.81%
- SphereNet with learnable sigmoid SphereConv and standard softmax loss: 91.66%
- SphereNet with cosine SphereConv and weight-normalized softmax loss: 91.44%

#### Part 3: Training log

- Baseline: [here](baseline_training.log)
- SphereNet with cosine SphereConv and standard softmax loss: [here](results/spherenet_cos_standard_training.log).
- SphereNet with linear SphereConv and standard softmax loss: [here](results/spherenet_linear_standard_training.log).
- SphereNet with sigmoid SphereConv and standard softmax loss: [here](results/spherenet_sigmoid_standard_training.log).
- SphereNet with learnable sigmoid SphereConv and standard softmax loss: [here](results/spherenet_lr_sigmoid_standard_training.log).
- SphereNet with cosine SphereConv and weight-normalized softmax loss: [here](results/spherenet_cos_wnsoftmax_training.log).

### Notes
- Empirically, SphereNets have more accuracy gain with larger filter number. If the filter number is very small, SphereNets may yield slightly worse accuracy but can still achieve much faster convergence.
- SphereConv may be useful for RNNs and deep Q-learning where better convergence can help.
- By adding rescaling factors to SphereConv and make them learnable in order for the SphereNorm to degrade to the original convolution, we present a new normalization technique, SphereNorm. SphereNorm does not contradict with the BatchNorm, and can be used either with or without BatchNorm

### Third-party re-implementation
- TensorFlow: [code](https://github.com/unixpickle/spherenet) by [unixpickle](https://github.com/unixpickle).


### Contact

- [Weiyang Liu](https://wyliu.com)

Binary file added asserts/2dfourier.png
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added asserts/MNIST_featvis.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added asserts/convergence.jpg
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
158 changes: 158 additions & 0 deletions baseline/baseline_cnn.py
@@ -0,0 +1,158 @@
import tensorflow as tf
import numpy as numpy

class Baseline_CNN():
def get_conv_filter(self, shape, reg, stddev):
init = tf.random_normal_initializer(stddev=stddev)
if reg:
regu = tf.contrib.layers.l2_regularizer(self.wd)
filt = tf.get_variable('filter', shape, initializer=init,regularizer=regu)
else:
filt = tf.get_variable('filter', shape, initializer=init)

return filt

def get_bias(self, dim, init_bias, name):
with tf.variable_scope(name):
init = tf.constant_initializer(init_bias)
regu = tf.contrib.layers.l2_regularizer(self.wd)
bias = tf.get_variable('bias', dim, initializer=init, regularizer=regu)

return bias

def batch_norm(self, x, n_out, phase_train):
with tf.variable_scope('bn'):

gamma = self.get_bias(n_out, 1.0, 'gamma')
beta = self.get_bias(n_out, 0.0, 'beta')

batch_mean, batch_var = tf.nn.moments(x, [0,1,2], name='moments')
ema = tf.train.ExponentialMovingAverage(decay=0.999)

def mean_var_with_update():
ema_apply_op = ema.apply([batch_mean, batch_var])
with tf.control_dependencies([ema_apply_op]):
return tf.identity(batch_mean), tf.identity(batch_var)

mean, var = tf.cond(phase_train,
mean_var_with_update,
lambda: (ema.average(batch_mean), ema.average(batch_var)))
return tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)

def _max_pool(self, bottom, name):
return tf.nn.max_pool(bottom, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
padding='SAME', name=name)

def _get_filter_norm(self, filt):
eps = 1e-4
return tf.sqrt(tf.reduce_sum(filt*filt, [0, 1, 2], keep_dims=True)+eps)

def _get_input_norm(self, bottom, ksize, pad):
eps = 1e-4
shape = [ksize, ksize, bottom.get_shape()[3], 1]
filt = tf.ones(shape)
input_norm = tf.sqrt(tf.nn.conv2d(bottom*bottom, filt, [1,1,1,1], padding=pad)+eps)
return input_norm


def _add_orthogonal_constraint(self, filt, n_filt):

filt = tf.reshape(filt, [-1, n_filt])
inner_pro = tf.matmul(tf.transpose(filt), filt)

loss = 2e-4*tf.nn.l2_loss(inner_pro-tf.eye(n_filt))
tf.add_to_collection('orth_constraint', loss)

def _conv_layer(self, bottom, ksize, n_filt, is_training, name, stride=1, bn=False, relu=True, pad='SAME', norm='none', reg=False, orth=False, w_norm='none'):

with tf.variable_scope(name) as scope:
n_input = bottom.get_shape().as_list()[3]
shape = [ksize, ksize, n_input, n_filt]
print("shape of filter %s: %s" % (name, str(shape)))

filt = self.get_conv_filter(shape, reg, stddev=tf.sqrt(2.0/tf.to_float(ksize*ksize*n_input)))
conv = tf.nn.conv2d(bottom, filt, [1, stride, stride, 1], padding=pad)
xnorm = self._get_input_norm(bottom, ksize, pad)
wnorm = self._get_filter_norm(filt)

if w_norm == 'linear':
conv = conv/wnorm
conv = -0.63662*tf.acos(conv)+1
elif w_norm == 'cosine':
conv = conv/wnorm
elif w_norm == 'sigmoid':
k_value_w = 0.3
constant_coeff_w = (1 + numpy.exp(-numpy.pi/(2*k_value_w)))/(1 - numpy.exp(-numpy.pi/(2*k_value_w)))
conv = conv/wnorm
conv = constant_coeff_w*(1-tf.exp(tf.acos(conv)/k_value_w-numpy.pi/(2*k_value_w)))/(1+tf.exp(tf.acos(conv)/k_value_w-numpy.pi/(2*k_value_w)))
elif w_norm == 'none':
pass

if norm == 'linear':
conv = conv/xnorm
conv = conv/wnorm
conv = -0.63662*tf.acos(conv)+1
elif norm == 'cosine':
conv = conv/xnorm
conv = conv/wnorm
elif norm == 'sigmoid':
k_value = 0.3
constant_coeff = (1 + numpy.exp(-numpy.pi/(2*k_value)))/(1 - numpy.exp(-numpy.pi/(2*k_value)))
conv = conv/xnorm
conv = conv/wnorm
conv = constant_coeff*(1-tf.exp(tf.acos(conv)/k_value-numpy.pi/(2*k_value)))/(1+tf.exp(tf.acos(conv)/k_value-numpy.pi/(2*k_value)))
elif norm == 'lr_sigmoid':
k_value_lr = tf.get_variable('k_value_lr', n_filt,
initializer=tf.constant_initializer(0.7),
dtype=tf.float32)
k_value_lr = tf.abs(k_value_lr) + 0.05
constant_coeff = (1 + tf.exp(-numpy.pi/(2*k_value_lr)))/(1 - tf.exp(-numpy.pi/(2*k_value_lr)))
conv = conv/xnorm
conv = conv/wnorm
conv = constant_coeff*(1-tf.exp(tf.acos(conv)/k_value_lr-numpy.pi/(2*k_value_lr)))/(1+tf.exp(tf.acos(conv)/k_value_lr-numpy.pi/(2*k_value_lr)))
elif norm == 'none':
pass

if orth:
self._add_orthogonal_constraint(filt, n_filt)

if bn:
conv = self.batch_norm(conv, n_filt, is_training)

if relu:
return tf.nn.relu(conv)
else:
return conv

# Input should be an rgb image [batch, height, width, 3]
def build(self, rgb, n_class, is_training):
self.wd = 5e-4

feat = (rgb - 127.5)/128.0

ksize = 3
n_layer = 3

#32X32
n_out = 128
for i in range(n_layer):
feat = self._conv_layer(feat, ksize, n_out, is_training, name="conv1_"+str(i), bn=True, relu=True, pad='SAME', norm='none', reg=True, orth=False)
feat = self._max_pool(feat, 'pool1')

#16X16
n_out = 192
for i in range(n_layer):
feat = self._conv_layer(feat, ksize, n_out, is_training, name="conv2_"+str(i), bn=True, relu=True, pad='SAME', norm='none', reg=True, orth=False)
feat = self._max_pool(feat, 'pool2')

#8X8
n_out = 256
for i in range(n_layer):
feat = self._conv_layer(feat, ksize, n_out, is_training, name="conv3_"+str(i), bn=True, relu=True, pad='SAME', norm='none', reg=True, orth=False)
feat = self._max_pool(feat, 'pool3')

self.fc6 = self._conv_layer(feat, 4, 256, is_training, "fc6", bn=True, relu=False, pad='VALID', norm='none', reg=True, orth=False)

self.score = self._conv_layer(self.fc6, 1, n_class, is_training, "score", bn=False, relu=False, pad='VALID', norm='none', reg=True, orth=False, w_norm='none')

self.pred = tf.squeeze(tf.argmax(self.score, axis=3))
Binary file added baseline/baseline_cnn.pyc
Binary file not shown.
4 changes: 4 additions & 0 deletions baseline/baseline_models/.gitignore
@@ -0,0 +1,4 @@
# Ignore everything in this directory
*
# Except this file
!.gitignore

0 comments on commit 81f6e5b

Please sign in to comment.