
[MRG] Generic multi layer perceptron #3204

Conversation

IssamLaradji
Contributor

Currently I am implementing layers_coef_ to allow for any number of hidden layers.

This pull request is to implement the generic Multi-layer perceptron as part of the GSoC 2014 proposal.

The expected time to finish this pull request is June 15

The goal is to extend the multi-layer perceptron to support more than one hidden layer, to support a pre-training phase (initializing weights through Restricted Boltzmann Machines) followed by a fine-tuning phase, and to write its documentation.

This directly follows from this pull-request: #2120

TODO:

  • replace private attributes initialized in _fit with local variables and pass them as arguments to private helper methods, to make the code more readable and to reduce the pickled model size by not storing things that are not necessary at prediction time.
  • refactor the _fit method to call into submethods for the different algorithms.
  • introduce self.t_ to store the SGD learning rate progress and decouple it from self.n_iter_, which should consistently track epochs.
  • issue a ConvergenceWarning whenever max_iter is reached during fit (a sketch follows below).
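A minimal sketch of that last item, assuming a hypothetical helper called at the end of fit; ConvergenceWarning is imported from its current home in sklearn.exceptions, and the message wording is an assumption:

import warnings

from sklearn.exceptions import ConvergenceWarning


def _warn_if_not_converged(n_iter, max_iter):
    # Hypothetical helper: flag that the optimizer exhausted max_iter
    # before meeting the stopping criterion.
    if n_iter >= max_iter:
        warnings.warn(
            "Maximum iterations (%d) reached and the optimization has not "
            "converged yet." % max_iter,
            ConvergenceWarning,
        )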

@larsmans
Member

What's the todo list for this one?

@IssamLaradji
Contributor Author

Hi @larsmans, the todo list is:

  1. it should support more than one hidden layer, so there would be one generic layer list, layer_coef_ (a sketch of what that could look like follows below)
  2. it should support weight initialization using trained Restricted Boltzmann Machines, like the approach proposed by Hinton et al. (2006): http://www.cs.toronto.edu/~fritz/absps/ncfast.pdf
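To illustrate item 1, here is a minimal sketch (not the PR's actual code) of a forward pass that iterates over a generic list of coefficient matrices; the names layers_coef_ and layers_intercept_ follow this thread, and the tanh activation is just an assumption:

import numpy as np


def _forward_pass(X, layers_coef_, layers_intercept_, activation=np.tanh):
    """Propagate X through an arbitrary number of layers.

    layers_coef_[i] has shape (n_units_of_layer_i, n_units_of_layer_i_plus_1)
    and layers_intercept_[i] has shape (n_units_of_layer_i_plus_1,).
    """
    out = X
    n_layers = len(layers_coef_)
    for i, (coef, intercept) in enumerate(zip(layers_coef_, layers_intercept_)):
        out = np.dot(out, coef) + intercept
        # hidden layers get the non-linearity; the output layer is left to
        # the output activation (softmax / identity), applied elsewhere
        if i < n_layers - 1:
            out = activation(out)
    return out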

@ogrisel
Member

ogrisel commented May 27, 2014

For the weight init, I would just use a warm_start=True constructor param and let the user set the layers_coef_ and layers_intercept_ attributes manually, as is done for other existing models such as SGDClassifier.

@jnothman
Member

Out of curiosity, does RBM initialisation mean that fit may be provided with some unlabelled samples?


@IssamLaradji
Contributor Author

@ogrisel should we include another parameter - like unsupervised_weight_init_ - that runs an RBM (or any unsupervised learning algorithm) to initialize the layer weights? I believe warm_start starts training from the previously trained weights but does not necessarily use an unsupervised learning algorithm for weight initialization.

@jnothman yes, an RBM trains on the unlabeled samples and its new, trained weights become the initial weights of the corresponding layer in the multi-layer perceptron. The image below shows a basic idea of how this is done.
[image: RBM deep belief network]

@larsmans
Member

I think we can leave the RBM init to a separate PR.

@IssamLaradji
Contributor Author

@larsmans sure thing :)

For the Travis build, I believe the error is coming from OrthogonalMatchingPursuitCV, reported at line 5442.

@ogrisel
Member

ogrisel commented May 27, 2014

+1 for leaving the RBM init to a separate PR. Also, there is no need to couple the two models: just extract the weights from a pipeline of RBMs, manually set them as the layers_coef_ of an MLP with warm_start=True, and then call fit with the labels for fine-tuning.
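To make that workflow concrete, here is a rough sketch. The BernoulliRBM calls are real scikit-learn API; the MLP part is left commented out because it relies on the class and attribute names discussed in this thread (MultilayerPerceptronClassifier, n_hidden, layers_coef_, layers_intercept_), which are assumptions about this PR branch rather than a released API:

import numpy as np
from sklearn.neural_network import BernoulliRBM

rng = np.random.RandomState(0)
X_unlabeled = rng.rand(200, 64)                         # pre-training needs no labels
X_labeled, y = rng.rand(100, 64), rng.randint(0, 2, 100)

# Greedy layer-wise pre-training with a stack of RBMs.
rbm1 = BernoulliRBM(n_components=32, random_state=0).fit(X_unlabeled)
rbm2 = BernoulliRBM(n_components=16, random_state=0).fit(rbm1.transform(X_unlabeled))

# Hypothetical MLP usage, following this thread's naming:
# mlp = MultilayerPerceptronClassifier(n_hidden=[32, 16], warm_start=True)
# BernoulliRBM stores weights as (n_components, n_features), hence the transposes.
# mlp.layers_coef_[0] = rbm1.components_.T
# mlp.layers_intercept_[0] = rbm1.intercept_hidden_
# mlp.layers_coef_[1] = rbm2.components_.T
# mlp.layers_intercept_[1] = rbm2.intercept_hidden_
# mlp.fit(X_labeled, y)   # supervised fine-tuning starts from the RBM weights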

For the travis build, I believe the error is coming from OrthogonalMatchingPursuitCV, given in line 5442

Not only that: the other builds have failed because the doctests don't pass either, as I told you earlier in the previous PR.

Classifier train-time test-time error-rate
------------------------------------------------------
nystroem_approx_svm 124.819s 0.811s 0.0242
MultilayerPerceptron 359.460s 0.217s 0.0271
@ogrisel
Member

Isn't it possible to find hyperparameter values that reach better accuracy with tanh activations? It should be possible to go below a 2% error rate with a vanilla MLP on MNIST.

@jnothman
Member

I assumed you intended to have additional unlabelled data, but perhaps working out the best way to incorporate the unlabelled data into the fitting procedure (particularly if you support partial_fit) might be a big question of its own. So I'm +1 for delaying that decision :)


@IssamLaradji
Contributor Author

@ogrisel I just brought the error rate down to 0.017 :)
(I fixed an issue with the tanh derivative - it hadn't passed the gradient test until now.)

@jnothman indeed, better to make RBM pipelining a separate PR.

Member

Glad you found the source of the problem, it's great to have unit tests that check the correctness of the gradient!

@IssamLaradji
Contributor Author

Hi guys, I made some major changes.

  1. The algorithm now supports more than one hidden layer by simply passing a list of values to the n_hidden parameter.
    For example, for 3 hidden layers where the first and second layers have 100 neurons and the third has 50, the list would be n_hidden = [100, 100, 50] (see the example after this list).

  2. I improved the speed of the implementation by more than 25% by removing a redundant loop.

  3. I improved the documentation by making it more comprehensive.
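For instance, a usage sketch under the assumption that this PR branch exposes a MultilayerPerceptronClassifier with the n_hidden parameter described above (the merged API later renamed both the class and the parameter):

import numpy as np

rng = np.random.RandomState(0)
X_train, y_train = rng.rand(200, 20), rng.randint(0, 3, 200)

# Assumed class/parameter names from this PR branch; the estimator that was
# eventually merged is MLPClassifier and the parameter became hidden_layer_sizes.
from sklearn.neural_network import MultilayerPerceptronClassifier

clf = MultilayerPerceptronClassifier(n_hidden=[100, 100, 50])  # three hidden layers
clf.fit(X_train, y_train)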

Your feedback will be greatly appreciated. Thank you! :)

@coveralls

Coverage increased (+0.16%) when pulling 2e8dc56 on IssamLaradji:generic-multi-layer-perceptron into daa1dba on scikit-learn:master.

@ogrisel
Member

ogrisel commented Jun 10, 2014

@IssamLaradji great work! I will try to review in more detail soon. @jaberg and @kastnerkyle might be interested in reviewing this as well.

Can you please fix the remaining expit-related failure under Python 3 with recent numpy / scipy?

https://travis-ci.org/scikit-learn/scikit-learn/jobs/27179454#L5790


+ Since hidden layers in MLP make the loss function non-convex
- which contains more than one local minimum, random weights'
initialization could impact the predictive accuracy of a trained model.
@ogrisel
Member

I would rather say: "meaning that different random initializations of the weights can lead to trained models with varying validation accuracy".

@IssamLaradji
Contributor Author

Thanks for the feedback @ogrisel. I improved the documentation further, making it more didactic - especially in the mathematical formulation section.

For the expit-related failure under Python 3, I am not sure how to fix the problem since I am using the expit version provided by scikit-learn. Isn't the problem within sklearn.utils.fixes?

Thanks.

@kastnerkyle
Member

This looks pretty cool so far - I will run some trials on it and try to understand the py3 issues.

Things that would be nice, though maybe not strictly necessary for a first cut PR:

A constructor arg for a custom loss function instead of a fixed one (maybe it is against the API). Thinking of things like cross-entropy, hinge loss à la Charlie Tang, etc. instead of standard softmax or what have you. It would be nice to have a few default ones available by strings, with the ability to create a custom one if needed (see the sketch below).
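One common way to expose that (a sketch, not this PR's code; the function and dict names here are made up) is a dict of named losses plus a callable escape hatch:

import numpy as np


def squared_loss(y_true, y_pred):
    """Mean squared error, halved (common for regression back-prop)."""
    return ((y_true - y_pred) ** 2).mean() / 2


def log_loss(y_true, y_prob):
    """Cross-entropy for one-hot encoded targets."""
    y_prob = np.clip(y_prob, 1e-10, 1 - 1e-10)
    return -np.mean(np.sum(y_true * np.log(y_prob), axis=1))


LOSS_FUNCTIONS = {"squared_loss": squared_loss, "log_loss": log_loss}


def resolve_loss(loss):
    """Accept a registered loss name (string) or a user-supplied callable."""
    return loss if callable(loss) else LOSS_FUNCTIONS[loss]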

I like @ogrisel's suggestion for layer_coefs_. It would be useful to run experiments with KMeans networks and also pretraining with autoencoders instead of RBMs. This also opens the door for side packages that can take in weights from other nets (looking at Overfeat, Decaf, Caffe, pylearn2, etc.) and load them into sklearn. This is more a personal interest of mine, but it is nice to see the building blocks there.

It is also plausible that very deep nets are possible to use in feedforward mode on the CPU, even if we can't train them in sklearn directly.

Questions:
I see you have worked on deep autoencoders before - will this framework support that as well? In other words, can layer sizes be different but complementary? Or are they expected to be a "block" (uniform in size)?

I also like the support for other optimizers - it would be sweet to get a Hessian-free optimizer into scipy and use it in this general setup. It could make deep-ish NN work somewhat accessible without a GPU, though CG is what (I believe) Hinton used for the original DBM/pretraining paper.

@ogrisel
Member

ogrisel commented Jun 11, 2014

@IssamLaradji indeed, it would be interesting to run a benchmark of L-BFGS vs CG, and maybe other optimizers from scipy.optimize, perhaps on (a subset of) MNIST for instance.

@ogrisel
Member

ogrisel commented Jun 11, 2014

We might want to make it possible to use any optimizer from scipy.optimize if the API is homogeneous across all optimizers (I have not checked).

@ogrisel
Member

ogrisel commented Jun 11, 2014

@IssamLaradji about the expit pickling issue, it looks like a bug in numpy. I am working on a fix.

@ogrisel
Member

ogrisel commented Jun 11, 2014

I submitted a bugfix upstream: numpy/numpy#4800 . If the fix is accepted we might want to backport it in sklearn.utils.fixes.

@ogrisel
Member

ogrisel commented Jun 11, 2014

@IssamLaradji actually, can you please try to add the ufunc fix to sklearn.utils.exists now to check that it works for us?

Try to add something like:

import pickle

from scipy.special import expit  # import added so the snippet is self-contained

try:
    pickle.loads(pickle.dumps(expit))
except AttributeError:
    # monkeypatch numpy to backport a fix for:
    # https://github.com/numpy/numpy/pull/4800
    import numpy.core

    def _ufunc_reconstruct(module, name):
        mod = __import__(module, fromlist=[name])
        return getattr(mod, name)

    numpy.core._ufunc_reconstruct = _ufunc_reconstruct

@IssamLaradji
Contributor Author

Hi @kastnerkyle and @ogrisel, thanks for the replies.

  1. Custom loss function: I could add a parameter to the constructor that accepts strings for selecting the loss function. (In fact, I have done that in my older implementation, but was told to remove it since there weren't enough loss functions)

  2. Pre-training: I could add a pipeline with a placeholder that selects a pre-trainer for the weights. Although I was told to keep that for the next PR, I don't see the harm in adding an additional constructor parameter and a small method containing the pre-trainer for a quick test :).

  3. Deep Auto-encoder: yes, a sparse autoencoder is a simple adaptation of the feedforward network - I simply need to inject a sparsity parameter into the loss function and its derivatives.

For the layer sizes, they can differ in any way - for example, 1024-512-256-128-64-28 - but, as Hinton said, no particular set of layer sizes is inherently justified, since it depends on the problem instance. Anyhow, this framework can support any set of layer sizes, even if they are larger than the number of features.

  4. Selecting scipy optimizers: my older implementation of the vanilla MLP supported all scipy optimizers through the generic scipy minimize method, but there was one problem: it required users to have SciPy 0.13+, while scikit-learn requires SciPy >= 0.7. If we could raise the SciPy version requirement, I could easily have this support all scipy optimizers.

Anyhow, L-BFGS is now a state-of-the-art optimizer. I tested it against CG, and L-BFGS always performed better and faster than CG for several datasets (most other optimizers were unsuitable and did not come close to CG and L-BFGS as far as speed and accuracy are concerned, but the scipy method also supports custom optimizers, which is very useful).

This claim is also supported by Adam Coates and Andrew Ng here: http://cs.stanford.edu/people/ang/?portfolio=on-optimization-methods-for-deep-learning

But I did read that CG can perform better and faster on special kinds of datasets, so I would be all for adding the generic scipy optimizer if it weren't for the minimum version issue. What do you think?

For the ufunc fix, did you mean sklearn.utils.fixes? My sklearn version doesn't have sklearn.utils.exists :( I added the fix to sklearn.utils.fixes and pushed the code to see if it resolves the expit problem.

Thank you.

@coveralls

Coverage increased (+0.16%) when pulling 1d4911b on IssamLaradji:generic-multi-layer-perceptron into daa1dba on scikit-learn:master.


elif 'l-bfgs':
    self._backprop_lbfgs(
        X, y, n_samples)
Member

Please put method calls on one line when they fit in 80 columns:

             self._backprop_lbfgs(X, y, n_samples)

@ogrisel
Member

ogrisel commented Jun 11, 2014

About the optimizers, thanks for the reference comparing L-BFGS and CG. We could add support for arbitrary scipy optimizers and raise a RuntimeException if the version of scipy is too low (with an informative error message), while still using fmin_l_bfgs_b directly by default so that we keep backward compatibility with old versions of scipy.
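A sketch of that version-gated dispatch (the helper name, parameter names, and error message are assumptions; the 0.13 floor is the version mentioned earlier in this thread):

import scipy
from scipy.optimize import fmin_l_bfgs_b


def _minimize_loss(objective, x0, solver="l-bfgs", **solver_kwargs):
    """Dispatch to a scipy optimizer; objective returns (loss, gradient)."""
    if solver == "l-bfgs":
        # works on every scipy version scikit-learn supports
        x_opt, loss, _ = fmin_l_bfgs_b(objective, x0, **solver_kwargs)
        return x_opt, loss
    # the generic scipy.optimize.minimize interface needs a newer scipy
    version = tuple(int(v) for v in scipy.__version__.split(".")[:2])
    if version < (0, 13):
        raise RuntimeError(
            "solver=%r requires scipy >= 0.13 (found %s); use the default "
            "'l-bfgs' solver instead." % (solver, scipy.__version__)
        )
    from scipy.optimize import minimize
    result = minimize(objective, x0, jac=True, method=solver, **solver_kwargs)
    return result.x, result.fun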

@ogrisel
Member

ogrisel commented Jun 11, 2014

It would be great to add squared_hinge and hinge loss functions. But in another PR.

I would also consider pre-training and sparse penalties for autoencoders for separate PRs.

@amueller
Member

amueller commented Dec 5, 2014

The gradients are fine, I think. I forgot to shuffle MNIST :-/ Now it looks good.
Maybe we want to set shuffle=True by default? It is so cheap compared to the backprop.

@amueller amueller mentioned this pull request Dec 5, 2014
@amueller
Member

amueller commented Dec 5, 2014

@IssamLaradji That was the place I meant. Sorry, I don't understand your explanation. Would the behavior of the code change if you discarded the return value of _compute_cost_grad?

@IssamLaradji
Contributor Author

@amueller oh I thought you meant something else.

It wouldn't change the behavior. I could discard the return value and the "left-hand side" of the assignment coef_grads, intercept_grads = _compute_cost_grad(...), and the results would remain the same.

Also, +1 for setting shuffle=True as default.

@amueller
Member

Training time for bench_mnist.py is twice as high on my box as what you reported, but only for the MLP; the others have comparable speed. Could you try running it again with the current parameters and see if it is still the same for you? How many cores do you have?

@IssamLaradji
Contributor Author

Strange, I ran it again now and I got,

Classifier              train-time      test-time    error-rate
----------------------------------------------------------------
MultilayerPerceptron    364.75999999    0.088        0.0178

which is half the training time you observed. Are you training using lbfgs or sgd? lbfgs tends to converge faster.

My machine is equipped with 8 GB RAM and an Intel® Core™ i7-2630QM processor (6M cache, 2.00 GHz).

@ogrisel
Member

ogrisel commented Dec 18, 2014

lbfgs tends to converge faster.

In my PR against @amueller's branch, with the enhanced "constant" learning rate and momentum, SGD seems to be faster than L-BFGS, although I have not plotted the "validation score vs epoch" curve as we have no way to do so at the moment.

@amueller
Member

I ran exactly the same code, so lbfgs. I think we should definitely do SGD, as it should be much faster on MNIST.

@joelkuiper

👍

@jeff-mettel
Contributor

Excellent work to all, and an exciting feature to be added to sklearn!

I have been looking forward to this functionality for a while - does merging still appear likely, or has momentum dissipated?

@amueller
Member

amueller commented Apr 7, 2015

It will definitely be merged, and soon.

@jeff-mettel
Contributor

@amueller - That's fantastic news, I'm very much looking forward to it. Great work as always!

@naught101

Using this a bit at the moment. Looks nice. Some notes:

  • Currently, if y in MultilayerPerceptronRegressor.fit() is a vector (dimensions (n,)), .predict() returns a 2d array with dimensions (n, 1). Other regressors just return a vector in the same format as y (see the snippet after this list).
  • That's a really long class name. Could it be MLPRegressor instead, similar to SGDRegressor? That abbreviation is common enough, I think (it is on the Wikipedia disambiguation page, it comes first in a Google search for 'MLP learning', and I don't think people will confuse it with My Little Pony).
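For the first point, the conventional fix (a sketch; the helper name is made up) is to ravel single-column predictions before returning them from predict():

import numpy as np

def _reshape_single_output(y_pred):
    # Flatten single-column output so a 1d y passed to fit() yields a 1d prediction.
    if y_pred.ndim == 2 and y_pred.shape[1] == 1:
        return y_pred.ravel()
    return y_pred

# e.g. an (n, 1) array becomes (n,):
assert _reshape_single_output(np.zeros((5, 1))).shape == (5,)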

@amueller
Member

amueller commented May 1, 2015

@naught101 It is a long name... maybe we should use MLP. Can you check if the shape is still wrong in #3939?

@naught101

@amueller: Yes, the shape is still wrong.

@amueller
Member

amueller commented May 4, 2015

Huh, I wonder why the common tests don't complain.

@amueller
Member

amueller commented May 4, 2015

Thanks for checking.

@amueller
Member

Merged via #5214

@amueller amueller closed this Oct 23, 2015
@IssamLaradji
Contributor Author

Wow!! That's fantastic!! :) :) Great work, team!

@naught101

Thank you to everyone who worked on this. It will be really useful.

@pasky

pasky commented Oct 24, 2015 via email

Yes, thank you very much! I've been waiting for this for a long time. (And sorry that I never ended up making good on my offer to help.)

@jnothman
Member

Wow!! That's fantastic!! :) :) Great work, team!

Yes, aren't sprints amazing from the outside? Dormant threads are suddenly marked merged, and that project you'd been trying to complete forever is now off your todo list and you're ready to book a holiday...

Thank you to all the sprinters from those of us on the outside, it's been a good one!


@IssamLaradji
Contributor Author

@jnothman indeed! It's a great surprise to see it merged, as I felt this would stay dormant for a much longer time.

Thanks a lot for your great reviews and effort, team!!
