cuDNN v5 Support #71

Merged · 11 commits · Apr 27, 2016
9 changes: 9 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,12 @@
## Apr 27, 2016

Features:

- Support for cuDNN v5
- Use cuDNN's BatchNormalization implementation as the default engine for the BN layer
- The BN layer now stores the running variance in its fourth blob.
- The script `python/bn_convert_style.py` is added to help convert BN models back and forth between the two storage styles.
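
For reference, the two storage styles differ only by an element-wise transform of the fourth blob. Below is a minimal NumPy sketch of the conversion, mirroring the math in the script added later in this PR (the standalone function names are illustrative):

```python
import numpy as np

def var_to_inv_std(var, eps=1e-5):
    # running variance -> running inverse standard deviation
    return 1.0 / np.sqrt(var + eps)

def inv_std_to_var(inv_std, eps=1e-5):
    # running inverse standard deviation -> running variance
    return np.power(inv_std, -2) - eps

# Round-trip check on a dummy per-channel variance vector.
var = np.array([0.5, 1.0, 2.0])
assert np.allclose(inv_std_to_var(var_to_inv_std(var)), var)
```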

## Dec 23, 2015

Features:
6 changes: 4 additions & 2 deletions README.md
@@ -5,6 +5,8 @@
This branch hosts the code for the technical report ["Towards Good Practices for Very Deep Two-stream ConvNets"](http://arxiv.org/abs/1507.02159), and more.

### Updates
- Apr 27, 2016
* cuDNN v5 support, featuring the super-fast Winograd convolution and the cuDNN implementation of BatchNormalization.
- Dec 23, 2015
* Refactored cudnn wrapper to control overall memory consumption. Will automatically find the best algorithm combination under memory constraint.
- Dec 17, 2015
@@ -38,8 +40,8 @@ Please see following instruction for accessing features above. More detailed doc
- Set `multi_scale` to `true` in `transform_param`
- In `transform_param`, specify `scale_ratios` as a list of floats no larger than one; the default is `[1, .875, .75, .65]`
- In `transform_param`, set `max_distort` to an integer to limit the aspect-ratio distortion; the default is `1` (see the sketch after this list)
- cuDNN v4
- The cuDNN v4 wrapper has optimized engines for convolution and batch normalization.
- cuDNN v5
- The cuDNN v5 wrapper has optimized engines for convolution and batch normalization.
- The solver protobuf config has a parameter `richness`, which specifies the total GPU memory in MB available to the cuDNN convolution engine as workspace. The default `richness` is 300 (300 MB). Using this parameter you can control the GPU memory consumption of training; the system will find the best algorithm setup under the memory limit for you.
- Training with multiple GPUs
- Requires OpenMPI > 1.7.4 ([Why?](https://www.open-mpi.org/faq/?category=runcuda)). **Remember to compile your OpenMPI with option `--with-cuda`**
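
As a rough illustration of the multi-scale cropping options above (the `scale_ratios` / `max_distort` item), the sketch below enumerates candidate crop sizes under the assumption that crops pair scaled side lengths whose ratio indices differ by at most `max_distort`; the function name and the `base_size` parameter are hypothetical, not part of this repository.

```python
def enumerate_crop_sizes(base_size, scale_ratios=(1.0, 0.875, 0.75, 0.65),
                         max_distort=1):
    """List candidate (crop_h, crop_w) pairs for multi-scale cropping (illustrative)."""
    sizes = []
    for i, rh in enumerate(scale_ratios):
        for j, rw in enumerate(scale_ratios):
            # Bounding the index difference limits aspect-ratio distortion.
            if abs(i - j) <= max_distort:
                sizes.append((int(round(base_size * rh)),
                              int(round(base_size * rw))))
    return sizes

print(enumerate_crop_sizes(256)[:4])  # [(256, 256), (256, 224), (224, 256), (224, 224)]
```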
4 changes: 1 addition & 3 deletions include/caffe/common_layers.hpp
@@ -735,7 +735,7 @@ class BNLayer : public Layer<Dtype> {


#if defined(USE_CUDNN)
#if CUDNN_VERSION_MIN(4, 0, 0)
#if CUDNN_VERSION_MIN(5, 0, 0)
/**
* @brief cuDNN implementation of BNLayer.
* Fallback to BNLayer for CPU mode.
@@ -767,8 +767,6 @@ class CuDNNBNLayer : public BNLayer<Dtype> {
cudnnTensorDescriptor_t top_desc_;
cudnnTensorDescriptor_t bn_param_desc_;

Blob<Dtype> scale_buf_;
Blob<Dtype> bias_buf_;
Blob<Dtype> save_mean_;
Blob<Dtype> save_inv_variance_;
};
3 changes: 3 additions & 0 deletions include/caffe/neuron_layers.hpp
@@ -500,6 +500,7 @@ class CuDNNReLULayer : public ReLULayer<Dtype> {
cudnnHandle_t handle_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
cudnnActivationDescriptor_t activation_desc_;
};
#endif

@@ -583,6 +584,7 @@ class CuDNNSigmoidLayer : public SigmoidLayer<Dtype> {
cudnnHandle_t handle_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
cudnnActivationDescriptor_t activation_desc_;
};
#endif

@@ -668,6 +670,7 @@ class CuDNNTanHLayer : public TanHLayer<Dtype> {
cudnnHandle_t handle_;
cudnnTensorDescriptor_t bottom_desc_;
cudnnTensorDescriptor_t top_desc_;
cudnnActivationDescriptor_t activation_desc_;
};
#endif

15 changes: 11 additions & 4 deletions include/caffe/util/cudnn.hpp
@@ -2,7 +2,7 @@
#define CAFFE_UTIL_CUDNN_H_
#ifdef USE_CUDNN

#include <cudnn.h>
#include "cudnn.h"

#include "caffe/common.hpp"
#include "caffe/proto/caffe.pb.h"
@@ -92,7 +92,7 @@ inline void createFilterDesc(cudnnFilterDescriptor_t* desc,
int n, int c, int h, int w) {
CUDNN_CHECK(cudnnCreateFilterDescriptor(desc));
CUDNN_CHECK(cudnnSetFilter4dDescriptor(*desc, dataType<Dtype>::type,
n, c, h, w));
CUDNN_TENSOR_NCHW, n, c, h, w));
}

template <typename Dtype>
@@ -123,8 +123,15 @@ inline void createPoolingDesc(cudnnPoolingDescriptor_t* pool_desc,
LOG(FATAL) << "Unknown pooling method.";
}
CUDNN_CHECK(cudnnCreatePoolingDescriptor(pool_desc));
CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode, h, w,
pad_h, pad_w, stride_h, stride_w));
#if CUDNN_VERSION_MIN(5, 0, 0)
CUDNN_CHECK(cudnnSetPooling2dDescriptor(*pool_desc, *mode,
CUDNN_PROPAGATE_NAN, h, w,
pad_h, pad_w, stride_h, stride_w));
#else
CUDNN_CHECK(cudnnSetPooling2dDescriptor_v4(*pool_desc, *mode,
CUDNN_PROPAGATE_NAN, h, w,
pad_h, pad_w, stride_h, stride_w));
#endif
}

} // namespace cudnn
42 changes: 42 additions & 0 deletions python/bn_convert_style.py
@@ -0,0 +1,42 @@
import numpy as np
import sys
import os
import os.path as osp
from argparse import ArgumentParser

import caffe


def main(args):
    net = caffe.Net(args.model, args.weights, caffe.TEST)
    conversion = args.conversion
    eps = args.epsilon
    for name, param in net.params.iteritems():
        if name.endswith('_bn'):
            if conversion == 'var_to_inv_std':
                var = param[3].data
                inv_std = 1. / np.sqrt(var + eps)
                param[3].data[...] = inv_std
            elif conversion == 'inv_std_to_var':
                inv_std = param[3].data
                var = np.power(inv_std, -2) - eps
                param[3].data[...] = var
            else:
                raise ValueError("Unknown conversion type {}".format(conversion))
    net.save(args.output)


if __name__ == '__main__':
    parser = ArgumentParser(
        description="This script converts between two styles of BN models. "
                    "Historically there have been two versions of the BN "
                    "implementation: one stores the running variance, the other "
                    "stores the running inverse std.")
    parser.add_argument('model', help="The deploy prototxt")
    parser.add_argument('weights', help="The caffemodel")
    parser.add_argument('--output', '-o', help="Output caffemodel")
    parser.add_argument('--conversion', type=str, default="inv_std_to_var",
                        help='can be "var_to_inv_std" or "inv_std_to_var"')
    parser.add_argument('--epsilon', type=float, default=1e-5,
                        help='the epsilon used in the conversion, defaults to 1e-5')
    args = parser.parse_args()
    main(args)
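
A hedged usage example for the converter above; the file paths are placeholders, not files from this PR. It converts an older model whose fourth BN blob stores the running inverse std into the new variance-storing style:

```python
import subprocess

# Placeholder paths -- substitute your own deploy prototxt and caffemodel.
subprocess.check_call([
    'python', 'python/bn_convert_style.py',
    'models/deploy.prototxt', 'models/old_inv_std.caffemodel',
    '--conversion', 'inv_std_to_var',
    '--output', 'models/new_var.caffemodel',
])
```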
30 changes: 0 additions & 30 deletions python/bn_var_to_inv_std.py

This file was deleted.

20 changes: 4 additions & 16 deletions src/caffe/layer_factory.cpp
@@ -168,35 +168,23 @@ shared_ptr<Layer<Dtype> > GetBNLayer(const LayerParameter& param) {
if (engine == BNParameter_Engine_DEFAULT) {
engine = BNParameter_Engine_CAFFE;
#if defined(USE_CUDNN)
#if CUDNN_VERSION_MIN(4, 0, 0)
// TODO : Currently we use caffe as the default engine, due to the performance issues with NV's implementation.
// Will switch back when this get fixed.

// engine = BNParameter_Engine_CUDNN;
#if CUDNN_VERSION_MIN(5, 0, 0)
engine = BNParameter_Engine_CUDNN;
#endif
#endif
}
#if defined(USE_CUDNN)
#if CUDNN_VERSION_MIN(4, 0, 0)
if (engine == BNParameter_Engine_CUDNN && param.bn_param().frozen()) {
LOG(WARNING) << "Layer " << param.name() << " switches back to CAFFE engine"
<< " as CUDNN engine doesn't support frozen.";
engine = BNParameter_Engine_CAFFE;
}
#endif
#endif
if (engine == BNParameter_Engine_CAFFE) {
LOG(INFO) << "Layer " << param.name() << " is using CAFFE engine.";
return shared_ptr<Layer<Dtype> >(new BNLayer<Dtype>(param));
#if defined(USE_CUDNN)
#if CUDNN_VERSION_MIN(4, 0, 0)
#if CUDNN_VERSION_MIN(5, 0, 0)
} else if (engine == BNParameter_Engine_CUDNN) {
LOG(INFO) << "Layer " << param.name() << " is using CUDNN engine.";
return shared_ptr<Layer<Dtype> >(new CuDNNBNLayer<Dtype>(param));
#endif
#endif
} else {
LOG(FATAL) << "Layer " << param.name() << " has unknown engine.";
LOG(FATAL) << "Layer " << param.name() << " calls cuDNN engine, but cuDNN version higher than 5.0 is not found";
}
}

57 changes: 42 additions & 15 deletions src/caffe/layers/bn_layer.cpp
@@ -37,10 +37,31 @@ void BNLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
this->blobs_[2]->mutable_cpu_data());
// moving average variance
this->blobs_[3].reset(new Blob<Dtype>(shape));
caffe_set(this->blobs_[3]->count(), Dtype(1),
caffe_set(this->blobs_[3]->count(), frozen_ ? Dtype(1) : Dtype(0),
this->blobs_[3]->mutable_cpu_data());
}
this->param_propagate_down_.resize(this->blobs_.size(), true);

// running average stats do not use weight decay or learning rate
while (this->layer_param_.param_size() < 4){
this->layer_param_.mutable_param()->Add();
}
this->layer_param_.mutable_param(2)->set_lr_mult(Dtype(0));
this->layer_param_.mutable_param(2)->set_decay_mult(Dtype(0));

this->layer_param_.mutable_param(3)->set_lr_mult(Dtype(0));
this->layer_param_.mutable_param(3)->set_decay_mult(Dtype(0));

// shutdown scale and bias update in frozen mode
if (this->frozen_){
// slope
this->layer_param_.mutable_param(0)->set_lr_mult(Dtype(0));
this->layer_param_.mutable_param(0)->set_decay_mult(Dtype(0));

// bias
this->layer_param_.mutable_param(1)->set_lr_mult(Dtype(0));
this->layer_param_.mutable_param(1)->set_decay_mult(Dtype(0));
}
}

template <typename Dtype>
@@ -118,6 +139,7 @@ void BNLayer<Dtype>::Forward_cpu(const vector<Blob<Dtype>*>& bottom,
caffe_copy(batch_statistic_.count(), this->blobs_[3]->cpu_data(),
batch_statistic_.mutable_cpu_data());
} else {
// calculate batch variance
caffe_powx(broadcast_buffer_.count(), const_top_data, Dtype(2),
broadcast_buffer_.mutable_cpu_data());
caffe_cpu_gemv<Dtype>(CblasNoTrans, num_ * channels_, height_ * width_,
caffe_cpu_gemv<Dtype>(CblasTrans, num_, channels_, Dtype(1) / num_,
spatial_statistic_.cpu_data(), batch_sum_multiplier_.cpu_data(),
Dtype(0), batch_statistic_.mutable_cpu_data());
// Add eps
caffe_add_scalar(batch_statistic_.count(), bn_eps_,
batch_statistic_.mutable_cpu_data());
// Inverse standard deviation
caffe_powx(batch_statistic_.count(), batch_statistic_.cpu_data(),
Dtype(-0.5), batch_statistic_.mutable_cpu_data());

// Add to the moving average
if (!frozen_) {
caffe_cpu_axpby(batch_statistic_.count(),
Dtype(1) - bn_momentum_, batch_statistic_.cpu_data(),
bn_momentum_, this->blobs_[3]->mutable_cpu_data());
}
caffe_cpu_axpby(batch_statistic_.count(),
Dtype(1) - bn_momentum_, batch_statistic_.cpu_data(),
bn_momentum_, this->blobs_[3]->mutable_cpu_data());
}

// Add eps
caffe_add_scalar(batch_statistic_.count(), bn_eps_,
batch_statistic_.mutable_cpu_data());
// Inverse standard deviation
caffe_powx(batch_statistic_.count(), batch_statistic_.cpu_data(),
Dtype(-0.5), batch_statistic_.mutable_cpu_data());

// Broadcast the inverse std
caffe_cpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num_, channels_, 1,
Dtype(1), batch_sum_multiplier_.cpu_data(), batch_statistic_.cpu_data(),
@@ -190,10 +213,14 @@ void BNLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
if (propagate_down[0]) {
const Dtype* const_top_diff = top[0]->cpu_diff();
Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
// Use the moving average inverse std
// Use the moving average variance
caffe_copy(batch_statistic_.count(), this->blobs_[3]->cpu_data(),
batch_statistic_.mutable_cpu_data());
// Multiple slope with inverse std
caffe_add_scalar(batch_statistic_.count(), bn_eps_,
batch_statistic_.mutable_cpu_data());
caffe_powx(batch_statistic_.count(), batch_statistic_.cpu_data(),
Dtype(-0.5), batch_statistic_.mutable_cpu_data());
// Divide slope with std
caffe_mul(batch_statistic_.count(), this->blobs_[0]->cpu_data(),
batch_statistic_.cpu_data(), batch_statistic_.mutable_cpu_data());
// Broadcast
@@ -315,4 +342,4 @@ STUB_GPU(BNLayer);

INSTANTIATE_CLASS(BNLayer);

} // namespace caffe
} // namespace caffe
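
To recap what the bn_layer.cpp changes above (and the matching bn_layer.cu changes below) accomplish: the fourth blob now accumulates the raw batch variance, and the epsilon plus inverse-std step is recomputed wherever normalization is needed, including when the backward pass rescales the slope by 1/sqrt(running_var + eps). A minimal NumPy sketch of the forward statistics flow, with illustrative variable names rather than the Caffe blobs:

```python
import numpy as np

def bn_forward_stats(x, running_mean, running_var, momentum=0.9, eps=1e-5,
                     frozen=False):
    """Illustrative per-channel BN statistics for NCHW data (not the Caffe code)."""
    if frozen:
        # Frozen/test mode: use the stored moving averages.
        mean, var = running_mean, running_var
    else:
        mean = x.mean(axis=(0, 2, 3))
        var = x.var(axis=(0, 2, 3))
        # The moving averages store the raw statistics; the fourth blob
        # therefore holds the running variance, not the inverse std.
        running_mean[:] = momentum * running_mean + (1.0 - momentum) * mean
        running_var[:] = momentum * running_var + (1.0 - momentum) * var
    # eps and the inverse std are applied at use time, never stored.
    inv_std = 1.0 / np.sqrt(var + eps)
    return (x - mean[None, :, None, None]) * inv_std[None, :, None, None]
```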
26 changes: 15 additions & 11 deletions src/caffe/layers/bn_layer.cu
@@ -67,19 +67,19 @@ void BNLayer<Dtype>::Forward_gpu(const vector<Blob<Dtype>*>& bottom,
caffe_gpu_gemv<Dtype>(CblasTrans, num_, channels_, Dtype(1) / num_,
spatial_statistic_.gpu_data(), batch_sum_multiplier_.gpu_data(),
Dtype(0), batch_statistic_.mutable_gpu_data());
// Add eps
caffe_gpu_add_scalar(batch_statistic_.count(), bn_eps_,
batch_statistic_.mutable_gpu_data());
// Inverse standard deviation
caffe_gpu_powx(batch_statistic_.count(), batch_statistic_.gpu_data(),
Dtype(-0.5), batch_statistic_.mutable_gpu_data());

// Add to the moving average
if (!frozen_) {
caffe_gpu_axpby(batch_statistic_.count(),
Dtype(1) - bn_momentum_, batch_statistic_.gpu_data(),
bn_momentum_, this->blobs_[3]->mutable_gpu_data());
}
caffe_gpu_axpby(batch_statistic_.count(),
Dtype(1) - bn_momentum_, batch_statistic_.gpu_data(),
bn_momentum_, this->blobs_[3]->mutable_gpu_data());
}

// Add eps
caffe_gpu_add_scalar(batch_statistic_.count(), bn_eps_,
batch_statistic_.mutable_gpu_data());
// Inverse standard deviation
caffe_gpu_powx(batch_statistic_.count(), batch_statistic_.gpu_data(),
Dtype(-0.5), batch_statistic_.mutable_gpu_data());
// Broadcast the inverse std
caffe_gpu_gemm<Dtype>(CblasNoTrans, CblasNoTrans, num_, channels_, 1,
Dtype(1), batch_sum_multiplier_.gpu_data(), batch_statistic_.gpu_data(),
@@ -133,6 +133,10 @@ void BNLayer<Dtype>::Backward_gpu(const vector<Blob<Dtype>*>& top,
// Use the moving average variance
caffe_copy(batch_statistic_.count(), this->blobs_[3]->gpu_data(),
batch_statistic_.mutable_gpu_data());
caffe_gpu_add_scalar(batch_statistic_.count(), bn_eps_,
batch_statistic_.mutable_gpu_data());
caffe_gpu_powx(batch_statistic_.count(), batch_statistic_.gpu_data(),
Dtype(-0.5), batch_statistic_.mutable_gpu_data());
// Multiple slope with inverse std
caffe_gpu_mul(batch_statistic_.count(), this->blobs_[0]->gpu_data(),
batch_statistic_.gpu_data(), batch_statistic_.mutable_gpu_data());
10 changes: 1 addition & 9 deletions src/caffe/layers/cudnn_bn_layer.cpp
@@ -9,22 +9,14 @@
#include "caffe/util/math_functions.hpp"
#include "caffe/vision_layers.hpp"

#if CUDNN_VERSION_MIN(4, 0, 0)
#if CUDNN_VERSION_MIN(5, 0, 0)

namespace caffe {

template <typename Dtype>
void CuDNNBNLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom,
const vector<Blob<Dtype>*>& top) {
BNLayer<Dtype>::LayerSetUp(bottom, top);
if (this->bn_eps_ < CUDNN_BN_MIN_EPSILON) {
LOG(WARNING) << "bn_eps is set to CUDNN_BN_MIN_EPSILON.";
// Merely setting as CUDNN_BN_MIN_EPSILON fails the check due to
// float / double precision problem.
this->bn_eps_ = CUDNN_BN_MIN_EPSILON * 10;
}
scale_buf_.ReshapeLike(*(this->blobs_[0]));
bias_buf_.ReshapeLike(*(this->blobs_[1]));
save_mean_.ReshapeLike(*(this->blobs_[2]));
save_inv_variance_.ReshapeLike(*(this->blobs_[3]));
