
[December 2014] benchmarking Imagenet winners #35

Open
soumith opened this issue Dec 8, 2014 · 40 comments

@soumith
Owner

soumith commented Dec 8, 2014

Time to take these benchmarks forward to a more meaningful metric (it has taken a while, but this is, after all, a side project for fun).

I've added benchmarks for the following networks:

  • VGG Model A (2nd place, Imagenet 2014)
  • Overfeat Fast (1st place, Imagenet 2013 Detection)
  • AlexNet (the holy network)

So far I have covered only two sets of kernels:

  • Torch7 CUNN
  • NVIDIA CuDNN R1

In the next week or two, I will try to get Caffe in there as well. If there's a volunteer, this will happen faster; if not, it will happen at my own pace, since I am not exactly well versed with Caffe's configuration files.
I will try to get cuda-convnet2 in as well, but it is currently failing some numFilters assertions (it supports only certain multiples of input/output plane configurations); I will have to look into it.
I am looking for a volunteer to do this on the Theano side, @benanne ??
The rest of the libraries are mostly poorly documented, and it took me a lot of effort to get the first round of benchmarks. Their kernels are not really at the top of the table either, so I will skip them.

Now, coming to GoogleNet: I coded the config in Torch (https://github.com/soumith/convnet-benchmarks/blob/master/torch7/imagenet_winners/googlenet.lua), but it is too big to fit on a Titan Black (even with batch size 1). I will try to benchmark it across 4 K40 cards (12GB each); I have a server with the 4 cards in a single machine, and Torch7 supports multi-GPU now. Let's see, it will be an interesting exercise.

Have a happy NIPS, everyone. The day CuDNN R2 releases, I will have those numbers as well (that day is in the near future, I believe).

@soumith
Owner Author

soumith commented Dec 8, 2014

Here are the CuDNN, Torch7-CuNN numbers:
https://github.com/soumith/convnet-benchmarks/blob/master/torch7/imagenet_winners/output.log

The same directory contains the benchmark script, and the model configurations for all the networks.

@liuliu
Contributor

liuliu commented Dec 8, 2014

(y)

@benanne

benanne commented Dec 8, 2014

I can probably try coding up those first 3, but GoogLeNet in Theano is still a challenge at the moment, and it would require some tricks that would slow it down unnecessarily (like the 3x3, stride-1 pooling that retains the size of the input, for example). So it's probably not worth spending time on that just yet. I have a lot on my plate at the moment, but hopefully I'll find some time for it this week.

EDIT: actually, looking into this, Theano's support for 'same' convolutions / arbitrary zero padding is not good enough yet for this to be worth the effort. Some of the individual implementations support it now, but the overarching T.nnet.conv.conv2d only supports 'valid' and 'full' modes, so that rules out the metaoptimizer, for example. There was discussion about implementing this recently, but as far as I can tell it hasn't happened: Theano/Theano#2118

@soumith
Owner Author

soumith commented Dec 18, 2014

A small preview on AlexNet and Overfeat:
fbcufft is the winner, and cuda-convnet2 is a close second. cudnn does not even come close, and I am expecting Caffe to be in cudnn territory. Anyway, I will be doing the Caffe numbers as well; working on it.
https://gist.github.com/soumith/e6297e93dd2fe3751562

I made a mistake in the first round of cuda-convnet2 numbers: I did not play around with the partialSum setting, which makes a real difference in the gradParameters computation. I now use the ideal settings proposed by Alex in this file:
https://code.google.com/p/cuda-convnet2/source/browse/layers/layers-imagenet-1gpu.cfg

@soumith
Owner Author

soumith commented Dec 18, 2014

fbcufft's source is here btw: https://github.com/facebook/fbcunn/tree/master/layers/cuda/fft

@stencilman

Wow! Looks like we should port this to Theano!

@benanne

benanne commented Dec 18, 2014

Awesome! Very happy the FFT approach is still being pursued and it's starting to bear fruit :) It's also cool that it's using CuFFT and doesn't require a custom implementation, which should make it resilient to changes in GPU architectures, just like the Caffe/cudnn approach.
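
For reference, the idea the FFT approach exploits is just the convolution theorem. A minimal numpy/scipy sketch (purely illustrative, not the fbcunn/cuFFT kernels; sizes are arbitrary):

```python
# Purely illustrative: linear convolution computed via zero-padded FFTs,
# checked against a direct convolution.
import numpy as np
from scipy.signal import convolve2d

rng = np.random.RandomState(0)
x = rng.randn(32, 32)          # input feature map
w = rng.randn(5, 5)            # filter

out_shape = (x.shape[0] + w.shape[0] - 1, x.shape[1] + w.shape[1] - 1)
y_fft = np.fft.irfft2(np.fft.rfft2(x, s=out_shape) * np.fft.rfft2(w, s=out_shape),
                      s=out_shape)
y_direct = convolve2d(x, w, mode='full')

print(np.allclose(y_fft, y_direct))   # True: same 'full' convolution
```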

Is there anything special that was done to make it run fast and efficiently, compared to Theano's Python implementation? Or is it just the change from Python to C++ that makes it both faster and more memory efficient? I'm curious :) Is there anywhere I can read about this (besides the source code, that is)?

@soumith
Owner Author

soumith commented Dec 18, 2014

We are releasing a paper on this, stay tuned. It was quite a piece of art to get it right; most of the work was done by my colleague Nicholas Vasilache.

@benanne

benanne commented Dec 18, 2014

Cool, would love to see a preprint of that at some point, if you and your colleagues are willing to share :)

@soumith
Owner Author

soumith commented Dec 18, 2014

Yes, releasing it on arXiv tomorrow.

@benanne

benanne commented Dec 18, 2014

Awesome!

@soumith
Owner Author

soumith commented Dec 18, 2014

@benanne not sure if you've noticed, but I changed AlexNet to use the "One weird trick" paper's implementation, rather than the original paper's implementation.
https://github.com/soumith/convnet-benchmarks/blob/master/torch7/imagenet_winners/alexnet.lua

Very curious to see Theano FFT numbers as well. We wanted to put them in the paper (for AlexNet/Overfeat), but we don't know how to build nets in Theano, tsk!

@benanne

benanne commented Dec 18, 2014

I hadn't noticed, thanks for pointing that out!

I don't know how to build AlexNet in Theano either; it's not easy to do 'same' convolutions across the different convolution implementations at the moment. The way we do it in Lasagne right now is by performing a full convolution and then slicing the result, but that seems wasteful because many implementations actually support implicit zero padding already (see https://github.com/benanne/Lasagne/blob/master/lasagne/layers/base.py#L635).
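
A minimal sketch of that full-convolution-then-slice workaround, written against the 2014-era theano.tensor.nnet.conv interface (the 3x3 filter size is just an example assumption):

```python
# Minimal sketch: 'same' output obtained from a 'full' convolution plus a slice.
import theano
import theano.tensor as T
from theano.tensor.nnet import conv

x = T.tensor4('x')   # (batch, channels, height, width)
w = T.tensor4('w')   # (num_filters, channels, 3, 3)

full = conv.conv2d(x, w, border_mode='full')   # spatial size grows to (H+2, W+2)
pad = 1                                        # (filter_size - 1) // 2
same = full[:, :, pad:-pad, pad:-pad]          # slice back down to (H, W)

f = theano.function([x, w], same)
```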

@soumith
Owner Author

soumith commented Dec 18, 2014

OK, so our FFT module does not support implicit zero padding yet either. So we add zero padding beforehand as a separate op, and then do valid convolutions. The timing difference between the two is negligible.
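
A small numpy/scipy sketch of that pad-then-valid-convolve approach (illustrative only, not the fbcunn module; sizes are arbitrary):

```python
# Illustrative only: pad the input explicitly, then run a plain 'valid'
# convolution; the output has the same spatial size as the input.
import numpy as np
from scipy.signal import correlate2d

rng = np.random.RandomState(0)
x = rng.randn(16, 16)
w = rng.randn(3, 3)

pad = (w.shape[0] - 1) // 2                 # 1 for a 3x3 filter
x_padded = np.pad(x, pad, mode='constant')  # explicit zero-padding step
y = correlate2d(x_padded, w, mode='valid')  # ordinary valid convolution

assert y.shape == x.shape
```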

@benanne

benanne commented Dec 18, 2014

Ok, then the problem is the 3x3 pooling with stride 2. Someone is working on implementing strided pooling, but it hasn't been merged: Theano/Theano#2196

It's possible to do a custom strided pooling implementation in pure Theano, but that's pretty slow in my experience.
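
For reference, this is the operation in question as a naive numpy sketch (3x3 window, stride 2, as discussed above; a reference implementation of the semantics, not the Theano op):

```python
# Naive reference implementation of overlapping max pooling.
import numpy as np

def max_pool_2d(x, size=3, stride=2):
    h_out = (x.shape[0] - size) // stride + 1
    w_out = (x.shape[1] - size) // stride + 1
    out = np.empty((h_out, w_out), dtype=x.dtype)
    for i in range(h_out):
        for j in range(w_out):
            out[i, j] = x[i * stride:i * stride + size,
                          j * stride:j * stride + size].max()
    return out

x = np.arange(49, dtype=np.float64).reshape(7, 7)
print(max_pool_2d(x).shape)   # (3, 3)
```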

@soumith
Owner Author

soumith commented Dec 18, 2014

Aren't you guys a little behind the times ;)
cudnn has the pooling implementation; I thought someone had written Theano bindings for it.

@benanne

benanne commented Dec 18, 2014

Hey, don't look at me, I've brought this up in the past :) (See Theano/Theano#2118 and Theano/Theano#2196 )

Yeah, the cudnn pooling has been wrapped, but ideally you don't want to use it directly. Rather, you want to use the default pooling functions/classes and rely on Theano's optimizer to swap in alternate implementations, depending on preference.

The problem is that Theano's default pooling operator is not expressive enough (no implicit padding, no strides). Same with the default convolution operator, which only allows full and valid convolutions, no custom padding or same convolutions.

Some implementations support these features, but there is no way to use them if you're doing things the 'correct' way - only if you use the custom implementations directly instead of relying on the optimizer.

I really like the concept of an optimizing compiler for this kind of code, but it's definitely true that Theano's default interfaces are not keeping up with the new features of some alternative implementations. They're getting stale and we're at a point where it's starting to limit the possibilities.

I can't do without the automatic differentiation though, so you won't be able to convert me just yet :p

@soumith
Owner Author

soumith commented Dec 18, 2014

For what it's worth, all of our 100+ nn ops (along with another 100+ from the community) have a backprop w.r.t. weights and a backprop w.r.t. input. And we are moving at a faster pace in terms of development (for example, GoogLeNet can be implemented, RNNs of all kinds, LSTMs, etc.).
In that respect, neither Theano nor Caffe can touch us.

@benanne

benanne commented Dec 18, 2014

Sure, but I'm thinking of stuff like the custom output layer I used for the Galaxy Challenge, to incorporate the constraints on the predicted probabilities (this monster: https://github.com/benanne/kaggle-galaxies/blob/master/custom.py#L576 ). There is no way in hell I'd want to implement that gradient myself. I don't think it's possible to avoid that in any other library for now. Theano's gradient implementations are at the level of individual operations, which is what makes the automatic differentiation so powerful. I guess the gradients implemented in Torch are too coarse-grained for that.
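
As a toy illustration of that point (a stand-in expression, not the Galaxy Challenge layer), Theano differentiates an arbitrary composite expression end-to-end without any hand-written layer gradient:

```python
# Toy stand-in: a composite output with a clip and a renormalisation,
# differentiated end-to-end by T.grad.
import theano
import theano.tensor as T

x = T.vector('x')
p = T.nnet.softmax(x.dimshuffle('x', 0))[0]   # softmax over a vector
q = T.clip(p, 1e-3, 1.0)
q = q / q.sum()                               # renormalise after clipping
loss = -T.log(q[0])

grad = T.grad(loss, x)                        # no hand-written gradient needed
f = theano.function([x], [loss, grad])
```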

@soumith
Owner Author

soumith commented Dec 18, 2014

Agreed, we don't have a graph optimizer. I implemented the same (or a similar) model here: https://github.com/soumith/galaxyzoo/blob/master/2_model.lua#L35
but Theano probably optimizes out certain common expressions and makes them efficient.

Anyway, the thread is going off track :) I will update when I have the end-to-end benchmarks on Caffe, and start a table with nn + cudnn + caffe + ccn2 for Overfeat, AlexNet, and OxfordNet.

@liuliu
Contributor

liuliu commented Dec 18, 2014

Have you used my "two buffer for backprop" trick on the GoogLeNet thing? :P
Interested to hear if you have any success reproducing GoogLeNet; I will probably start on that pretty soon.

@soumith
Owner Author

soumith commented Dec 19, 2014

@liuliu I have not used your two-buffer trick yet. So many things to try :)

@soumith
Owner Author

soumith commented Dec 19, 2014

Okay, so I added the Titan Black numbers for AlexNet, Overfeat, and VGG Model A to the README.md, for ccn2, cudnn Release 2, cudnn Release 1, and torch-cunn.

@nouiz
Contributor

nouiz commented Dec 20, 2014

Merry Christmas and Happy New Year!

(I won't be checking my emails/github frequently until next year, so do not expect fast response)

Thanks for the release of fbcufft; I think it will be useful to many other projects.

For the cudnn R2 benchmark, I think you should make 2 cases: one without extra memory and one with. As FFT and cuda-convnet[2] use (or can use) extra memory, I think knowing both timings from cudnn R2 would make a better comparison. Which one did you time?

@nouiz
Contributor

nouiz commented Dec 20, 2014

@benanne, I agree that having the CPU code support all the versions is great. But sometimes we need to go faster for research, and implementing just the case we use (CUDA) allows that. We then make the new features available in the "normal"/CPU interface later.

I must disagree with @soumith. Most Theano ops have had their grad implemented for a long time; I do not think Torch has more. In fact, in Theano we can in many cases take the second derivative, and the Lop/Rop needed for Hessian-free optimization. I haven't seen that in other software that automates this (though I haven't looked hard). So I do not think there are more grads in Torch :)

Also, regarding speed of development, we had LSTM and RNN before Torch, I think. For RNNs, @nicholas-leonard asked questions on our lab mailing list about how to implement them; we have had them for a long time. With the same functionality we can implement LSTM; no new functionality was needed. We are in the process of releasing an LSTM example based on some old code: lisa-lab/DeepLearningTutorials#55. There is also Alex's conv net code that was linked to by a Theano user; I forget where.

For multi-GPU, it has been possible with multiple processes for a long time, with an example. For multi-GPU in the same process, as in current Torch, there is a PR that is just waiting for my review :)

@benanne

benanne commented Dec 20, 2014

Merry Christmas and a happy new year to everyone as well :)

@nouiz: To be honest, I don't really care about CPU support for every feature either. All the machines I use have GPUs, including my laptop. But the way Theano is built requires using 'more general' ops in your graph that get replaced by optimizations to profit maximally. Having to use specialized ops in a graph directly undoes some of the advantages that 'the Theano way' offers.

And let's face it, for many people the CPU convolution is just a placeholder anyway, they never execute it directly. Whether it has all these features implemented or not wouldn't matter, there just needs to be a way to specify all the parameters so the more specialized implementations can pick them up when inserted during the optimization phase.

I would almost advocate for making the 'default' Theano convolution non-executable: just a placeholder to hold all the parameters, with the various executable versions inserted only at optimization time. With the current set-up, where the CPU version is also the placeholder, rapid progress in the GPU implementations seems to be held back by the fact that the new features can't be used in the "normal" Theano way: you have to use the specialized implementations directly if you want them. Using plain T.nnet.conv2d would require a matching CPU implementation to exist first.

With the advent of @f0k's meta-optimizers, this has become even more apparent: if you want to take advantage of those, you have to use conv2d in your graph and you are unable to use features like arbitrary zero padding that GpuCorrMM and dnn_conv now offer, for example.

In practice, what we're doing for Lasagne now is implementing several Conv2DLayer classes that each wrap a different implementation, to allow them to be easily swappable. This works, but it feels kind of wrong because the whole idea of Theano is to represent your computation as a graph so you can replace parts easily. So this problem would be much better solved at the Theano level.
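
A rough sketch of that wrapper pattern (class and method names here are hypothetical stand-ins, not the actual Lasagne API):

```python
# Hypothetical names: one thin layer class per convolution implementation,
# sharing a common interface so the implementations can be swapped.
from theano.tensor.nnet import conv

class Conv2DLayerBase(object):
    def __init__(self, filters, pad=1):
        self.filters = filters
        self.pad = pad

    def get_output_for(self, input):
        raise NotImplementedError

class Conv2DLayerFullSlice(Conv2DLayerBase):
    # uses only the default op: 'full' convolution, then slice down to 'same'
    def get_output_for(self, input):
        out = conv.conv2d(input, self.filters, border_mode='full')
        p = self.pad
        return out[:, :, p:-p, p:-p]

# A dnn_conv- or GpuCorrMM-backed subclass would instead pass the padding
# straight to the op; swapping implementations then means swapping one class.
```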

Everything I've said also goes for pooling of course (and probably a bunch of other operations). I don't know if making the default ops 'abstract' (i.e. without implementation) is the best solution, but the current set-up does seem to slow down the integration of new features.

That said, despite all my whining and complaining I still think Theano is an incredible piece of software that I've been able to rely on for almost 4 years now. It's certainly made my life a lot easier. So please don't take all of this the wrong way :)

... I feel like I've hijacked this issue twice over now :p Sorry @soumith ! If this point is open to discussion I could create an issue for it at the Theano repo.

@soumith
Owner Author

soumith commented Dec 20, 2014

I agree, we need to move the Theano discussion to the Theano repo, mainly because this repo is watched by a diverse set of people who might not be interested in a lot of these things.
PS: thanks to the CuDNN guys for giving us 3D convolutions in the new release (which don't need extra memory, of course); I've interfaced them, and they work great.

@soumith
Owner Author

soumith commented Dec 22, 2014

@nouiz For CUDNN R2, I used the "FASTEST" mode. Sometimes it uses some extra memory, but that is on the order of a few kilobytes, which I do not think is significant.

@soumith
Owner Author

soumith commented Dec 25, 2014

@benanne

benanne commented Dec 25, 2014

Excellent, thanks for the heads up!

@apark263

For L1 the IFM->OFM is 3->96, but in L2 it goes from 64->128. Were these numbers intentionally chosen to be mismatched for benchmarking purposes, or should L2 be 96->128?

@soumith
Owner Author

soumith commented Dec 26, 2014

@apark263 the numbers were chosen fairly arbitrarily, as what I thought might be a "typical" L2. I've moved to the more realistic, real-world "imagenet-winners" benchmarks for exactly this reason: the layer sizes were arbitrarily chosen.

@dadaism
Contributor

dadaism commented Feb 18, 2015

Hi @soumith, it seems that you've fixed the "numFilters assertions" problem of cuda-convnet2. I have the same problem; can you give me a hint on how you did that? Also, have you finished the Caffe versions of these four ImageNet networks? If you haven't and are busy, I think I can help.

@soumith
Owner Author

soumith commented Feb 19, 2015

Hey @dadaism. I fixed it in my cuda-convnet2.torch fork by rewriting some of the kernels without texture loading and adding them to the dispatcher for the case where the tensor is bigger than 2^27. You can find the commits doing so (I'm on my phone, so I can't pull them up now).

I have not had time to tackle the imagenet-winners benchmarks on the Caffe side; any help from you would be very welcome. I am working on benchmarking on the Maxwell architecture.

@andravin

Why not use 128-bit texture loads instead? That would extend the reach of the texture indices, reduce the number of load instructions, and reduce the indexing arithmetic.

@soumith
Owner Author

soumith commented Feb 19, 2015

It could be done, and the corner cases could be loaded via a regular load. It would also remove the 512MB limit in ccn2. No one has done it, that's all.

@andravin

Why is the VGG benchmark using a mini-batch size of 32? The paper seems to say the mini-batch size was 256 / 4 GPUs = 64 per GPU. http://arxiv.org/abs/1409.1556

@soumith
Owner Author

soumith commented Feb 22, 2015

Some of the libs don't support in-place ReLU (without which you won't fit mini-batch 64), and I was going to try moving the benchmarks to a GTX 980, so I made this small trade-off.
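
For context, a back-of-the-envelope estimate of why this matters (layer shapes follow the VGG model A configuration; forward conv activations only, so the figures are rough):

```python
# Rough, illustrative arithmetic only (forward conv activations, fp32);
# a non-in-place ReLU keeps a second copy of each conv output.
conv_outputs = [                      # (channels, height, width) per conv layer
    (64, 224, 224), (128, 112, 112),
    (256, 56, 56), (256, 56, 56),
    (512, 28, 28), (512, 28, 28),
    (512, 14, 14), (512, 14, 14),
]
batch, bytes_per_float = 64, 4

acts = sum(c * h * w for c, h, w in conv_outputs) * batch * bytes_per_float
print("conv activations (in-place ReLU):     %.2f GB" % (acts / 1024**3))
print("conv activations (out-of-place ReLU): %.2f GB" % (2 * acts / 1024**3))
```

At mini-batch 64 this roughly doubles the activation footprint (about 1.8 GB vs 3.5 GB, before gradients and workspace), which is consistent with not fitting without in-place ReLU.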

@andravin

Mini-batch size can have a drastic effect on performance, so changing it for a benchmark is significant. It would also be useful to know which libraries failed to run full VGG.

Maybe at least document clearly where you deviate from the original experiment and why.

@soumith
Owner Author

soumith commented Feb 22, 2015

@andravin I think you make a good point. I will just make it mini-batch size 64. Waiting for a new card for that to happen.

@soumith changed the title from "benchmarking Imagenet winners" to "[December 2014] benchmarking Imagenet winners" on Aug 2, 2015