[MRG] Multi-layer perceptron (MLP) #2120

Closed
wants to merge 38 commits

Conversation

@IssamLaradji
Contributor

IssamLaradji commented Jun 30, 2013

Multi-layer perceptron (MLP)

PR closed in favor of #3204

This is an extension of @larsmans' code.

A multilayer perceptron (MLP) is a feedforward artificial neural network model that tries to learn a function f(X) = y, where y is the output and X is the input. An MLP consists of multiple layers, typically an input layer, one hidden layer and an output layer, where each layer is fully connected to the next one. This is a classic algorithm that has been used extensively in neural networks.

Code Check out :

  1. git clone https://github.com/scikit-learn/scikit-learn
  2. cd scikit-learn/
  3. git fetch origin pull/2120/head:mlp
  4. git checkout mlp

Tutorial link:

- http://easymachinelearning.blogspot.com/p/multi-layer-perceptron-tutorial.html

Sample Benchmark:

- `MLP` on scikit's `Digits` dataset gives:
  - Score for `tanh-based sgd`: 0.981
  - Score for `logistic-based sgd`: 0.987
  - Score for `tanh-based l-bfgs`: 0.994
  - Score for `logistic-based l-bfgs`: 1.000
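
For reference, a rough sketch of how such a benchmark can be reproduced in the style of the PR's example script. The class and parameter names (`MultilayerPerceptronClassifier`, `algorithm`, `activation` and the values `"sgd"`/`"l-bfgs"`) follow those used later in this PR and may change; treat them as illustrative rather than the final API:

```python
from sklearn.datasets import load_digits
from sklearn.neural_network import MultilayerPerceptronClassifier  # name used later in this PR

digits = load_digits()
X, y = digits.data, digits.target

# Mirror the PR's example script: fit and report training accuracy for each
# combination of optimizer and hidden-layer activation.
for algorithm in ("sgd", "l-bfgs"):
    for activation in ("tanh", "logistic"):
        clf = MultilayerPerceptronClassifier(algorithm=algorithm,
                                             activation=activation,
                                             random_state=1).fit(X, y)
        print("training accuracy for %s-based %s: %f"
              % (activation, algorithm, clf.score(X, y)))
```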

TODO:

- Review
sklearn/neural_network/mlp.py
+from ..utils.extmath import logsumexp, safe_sparse_dot
+
+
+def validate_grad(J, theta, n_slice):

@amueller

amueller Jul 1, 2013

Member

you could move this into the test file.

@IssamLaradji

IssamLaradji Jul 1, 2013

Contributor

I was wondering where to put it. Thanks!

sklearn/neural_network/mlp.py
+
+ """
+
+ def __init__(

@amueller

amueller Jul 1, 2013

Member

n_hidden needs a default value. Maybe 100, just for good measure? Or the geometric / arithmetic mean of n_features and n_classes?

@IssamLaradji

IssamLaradji Jul 1, 2013

Contributor

Sure. I'll try both and see which one works better.

@larsmans

larsmans Jul 3, 2013

Member

Or sqrt(n_features). That's simple and doesn't grow too much with large numbers of features, so you don't immediately fill the swap when n_features is very large.

@IssamLaradji

IssamLaradji Jul 3, 2013

Contributor

That's smart! Thanks.
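
For reference, a minimal sketch of the candidate default heuristics discussed above; `default_n_hidden` is a hypothetical helper for illustration, not code from this PR:

```python
import numpy as np

def default_n_hidden(n_features, n_classes, strategy="sqrt"):
    """Hypothetical helper illustrating the candidate defaults discussed above."""
    if strategy == "constant":
        return 100                                          # fixed default, "just for good measure"
    if strategy == "mean":
        return int(round((n_features + n_classes) / 2.0))   # arithmetic mean of layer sizes
    # sqrt(n_features) grows slowly, so memory stays bounded for very wide data
    return max(1, int(np.sqrt(n_features)))

print(default_n_hidden(64, 10))   # digits-sized problem -> 8 hidden units
```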

@amueller

Member

amueller commented Jul 1, 2013

Could you please say to what extent this extends @larsmans' PR? This one seems to be completely in Python, while I remember @larsmans's to be in Cython, right? I'm not completely sure how large the benefit of Cython was, though.
Does this one support sparse matrices? cc @temporaer

@IssamLaradji

Contributor

IssamLaradji commented Jul 1, 2013

  1. @amueller, the missing part in @larsmans' code was the backpropagation, which was partly developed in Cython. I implemented that part using vectorized matrix operations, and it's quite fast: running the algorithm on the 'digits' dataset for 400 iterations with 50 hidden neurons takes about 5 seconds. I also added the option of using a secondary optimization algorithm, 'fmin_l_bfgs_b', which is as fast but achieves better classification performance with the same number of iterations. I'm also thinking of adding a third option: 'fmin_cg'. These optimizers are heavily used in neural networks.

I read that Cython code is easy to produce (just a matter of adding some prefixes and compiling the code). I will Cythonize the code and see if it adds benefits.

  2. Yes, it supports sparse matrices (via safe_sparse_dot).

Thanks for the review
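
To illustrate the optimizer interface being described (the objective returns both the cost and the gradient over a flat parameter vector), here is a self-contained sketch with a toy regularized logistic regression standing in for the MLP's backpropagation routine. Everything below is illustrative, not code from this PR:

```python
import numpy as np
from scipy.optimize import fmin_l_bfgs_b

rng = np.random.RandomState(0)
X = rng.randn(200, 5)
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(float)
alpha = 0.1  # small L2 penalty keeps the toy problem well conditioned

def cost_and_grad(w, X, y):
    p = 1.0 / (1.0 + np.exp(-X.dot(w)))                 # logistic output
    p = np.clip(p, 1e-12, 1 - 1e-12)                    # keep log() finite
    cost = -np.mean(y * np.log(p) + (1 - y) * np.log(1 - p))
    cost += 0.5 * alpha * w.dot(w)
    grad = X.T.dot(p - y) / X.shape[0] + alpha * w      # gradient of the penalized log loss
    return cost, grad

w0 = np.zeros(X.shape[1])
w_opt, final_cost, info = fmin_l_bfgs_b(cost_and_grad, w0, args=(X, y))
print("final cost: %.4f" % final_cost)
```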

sklearn/neural_network/mlp.py
+ activation: string, optional
+ Activation function for the hidden layer; either "sigmoid" for
+ 1 / (1 + exp(x)), or "tanh" for the hyperbolic tangent.
+ _lambda : float, optional

@larsmans

larsmans Jul 1, 2013

Member

We call this alpha almost everywhere else. Don't use a leading underscore on a parameter.

@larsmans

Member

larsmans commented Jul 1, 2013

I found that switching to Cython gave about an order of magnitude improvement over pure Python. We can merge this version as an intermediate; it looks very clean. How fast is it on 20newsgroups w/ 100 hidden units?

@amueller

Member

amueller commented Jul 1, 2013

Really? Why? In the other implementation, there was no gain at all. I guess you used one sample "mini-batches"?

@larsmans

Member

larsmans commented Jul 1, 2013

No, my implementation would take large batches, divide these into randomized minibatches internally (of user-specified size), then train on those. That gave much faster convergence, without the need to actually materialize the minibatches (no NumPy indexing).

@amueller

Member

amueller commented Jul 1, 2013

Ok, that makes sense and explains why the cython version is much faster.

@IssamLaradji

Contributor

IssamLaradji commented Jul 1, 2013

@larsmans, using all 20 categories of 20news (not the watered-down version), modeled with scikit-learn's tf-idf vectorizer, which yields an input matrix of 18828 rows and 74324 columns (features), and with 100 hidden neurons, the algorithm fit the whole sparse matrix at around 1 second per iteration. That seems like good enough speed for an MLP on such large data, but I might be wrong.
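
For context, a sketch of how such a matrix can be built; the vectorizer settings here are assumptions, and the exact shape depends on the corpus version and options used:

```python
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer

# Load all 20 categories and build a sparse tf-idf matrix; the thread reports
# an 18828 x 74324 matrix, which will vary with the corpus version and
# vectorizer options.
newsgroups = fetch_20newsgroups(subset='all')
X = TfidfVectorizer().fit_transform(newsgroups.data)   # scipy.sparse CSR matrix
y = newsgroups.target
print(X.shape)
```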

@larsmans

Member

larsmans commented Jul 1, 2013

What's the F1 score, and how many iterations are needed for it? (I got somewhat faster results, but I wonder if LBFGS converges faster than SGD.)

@IssamLaradji

Contributor

IssamLaradji commented Jul 1, 2013

Right now, I applied the code to 4 categories of the 20news corpus; with 100 iterations and 100 hidden neurons, l_bfgs achieved an average F1 score of 0.87. I might need to let the code run for a long time before it converges (it doesn't converge even after 300 iterations), so I suspect there is a bug in my code.

In your pull request you mentioned that you tested your code on a small subset of the 20news corpus and achieved similar results; did you use 4 categories too?

sklearn/neural_network/mlp.py
+ inds = np.arange(n_samples)
+ rng.shuffle(inds)
+ n_batches = int(np.ceil(len(inds) / float(self.batch_size)))
+ # Transpose improves performance (from 0.5 seconds to 0.05)

@larsmans

larsmans Jul 2, 2013

Member

Improves the performance of what? For dense or sparse data?

@IssamLaradji

IssamLaradji Jul 2, 2013

Contributor

The performance improved in calculating the cost and the gradient. It was observed on dense data; I didn't try it on sparse data yet.

It might look peculiar, but it has to do with the matrix multiplications. I played with the timeit library to understand the performance increase: if multiplying matrices A and B takes, say, 0.25 ms, then multiplying B.T with A.T could take twice as long, i.e. 0.5 ms. That small increase adds up when the cost and gradient are computed many times. In other words, multiplying matrices with different shapes can incur time overheads.

I will commit the non-transposed function to benchmark again, just to be safe.

@larsmans

larsmans Jul 2, 2013

Member

It's not too surprising; probably due to Fortran vs. C arrays (column-major or row-major). Input to np.dot should ideally be a C and a Fortran array, in that order, IIRC.
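
A quick way to check this on a given machine (measurement aid only; the outcome depends on the BLAS NumPy is linked against, so no particular result is claimed here):

```python
import numpy as np
from timeit import timeit

rng = np.random.RandomState(0)
A = rng.rand(2000, 500)              # C-ordered (row-major), NumPy's default
B_c = rng.rand(500, 800)             # C-ordered
B_f = np.asfortranarray(B_c)         # same values, Fortran-ordered (column-major)

for name, B in [("C-ordered B", B_c), ("Fortran-ordered B", B_f)]:
    t = timeit(lambda: np.dot(A, B), number=20)
    print("%-18s %.3f s for 20 products" % (name, t))
```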

@larsmans

larsmans Jul 2, 2013

Member

But anyway, my point was that performance figures out of context don't belong in code :)

@IssamLaradji

IssamLaradji Jul 2, 2013

Contributor

Thanks! I'll have such comments removed soon :)

@larsmans

Member

larsmans commented Jul 3, 2013

I have some MLP documentation lying around, I'll see if I can dig that up, too.

sklearn/neural_network/mlp.py
@@ -0,0 +1,617 @@
+"""Mulit-layer perceptron

@larsmans

larsmans Jul 3, 2013

Member

Typo

@IssamLaradji

Contributor

IssamLaradji commented Jul 5, 2013

Sorry for being slow to respond; I had a bug in the code that took time to fix because the transposed X was confusing everything :). I had a weird benchmark that made me think X.T improved performance, but in reality it did not, so I removed the transpose, making the code cleaner and easier to debug while leaving the performance unchanged.

Moreover, I just committed a lot of changes, including:

  • Optimization_method parameter for selecting any scipy optimizer
  • Support of SGD
  • Improved minibatch creation using scikit's gen_even_slices
    • (much faster than X[inds[minibatch::n_batches]]; see the sketch after this list)
  • Support of loss functions cross-entropy and square (more will be added)
  • Typos and name fixes
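
A minimal sketch of the gen_even_slices-based minibatching listed above; the shapes, batch size and the update step are placeholders, not the PR's actual code:

```python
import numpy as np
from sklearn.utils import gen_even_slices

X = np.random.RandomState(0).rand(1000, 64)
y = np.random.RandomState(1).randint(0, 10, size=1000)
batch_size = 200
n_batches = int(np.ceil(X.shape[0] / float(batch_size)))

# gen_even_slices yields slice objects, so each minibatch of X is a contiguous
# view rather than a fancy-indexed copy.
for batch in gen_even_slices(X.shape[0], n_batches):
    X_batch, y_batch = X[batch], y[batch]
    # ... one forward/backward pass and weight update on this minibatch ...
```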

The performance benchmark on the digits dataset (100 hidden neurons and 170 iterations):

  • SGD with cross-entropy loss
    • Score: 0.95
  • Optimization using CG (Conjugate Gradient) with cross-entropy loss
    • Score: 0.95
  • Optimization using l-bfgs with square loss
    • Score: 0.98 (it converged in 80 iterations)
  • Please note that the score is worse with the square loss for SGD and CG.

Will post the test results on the 20News dataset soon.

Some of the remaining TODO's would be:

  • Use sqrt(n_features) to select the number of hidden neurons
  • Update the documentation
  • Add verbose
  • Add a test file
  • Add an example file

Thank you for your great reviews!

@larsmans

Member

larsmans commented Jul 5, 2013

You can get a unit test by

git remote add larsmans git@github.com:larsmans/scikit-learn.git
git fetch larsmans
git cherry-pick 3176ce315b38176484fd57357496f8a6a0589071

I'd send you a PR, but it looks like the GitHub UI changes broke PRs between forks.

@IssamLaradji

Contributor

IssamLaradji commented Jul 6, 2013

@larsmans the unit test is very useful! I will rework the code as per the comments.

Thanks for the review!

@IssamLaradji

Contributor

IssamLaradji commented Jul 6, 2013

Updates

- Replaced scipy's `minimize` with `l-bfgs`, so as not to compel users to install SciPy 0.13.0+
- Renames, as per the comments, done
- Divided the estimator into `MLPClassifier` and `MLPRegressor`
- Set `square_loss` as the default for `MLPRegressor`
- Set `log` as the default for `MLPClassifier`
- Fixed long lines (some lines might still be long)
- Added learning rates: `optimal`, `constant`, and `invscaling`
- New benchmark on the `digits` dataset (100 iterations, 100 hidden neurons, `log` loss):
  - `tanh-based SGD`: `0.957`
  - `tanh-based l_bfgs`: `0.985`
  - `logistic-based SGD`: `0.992`
  - `logistic-based l_bfgs`: `1.000` (converged in `70` iterations)

These are interesting results, because tanh should make the algorithm converge faster than logistic. I suspect a bug in computing the deltas (lines 454 to 466) that leads to these odd results. I'll use the unit test to ensure the backpropagation is working correctly.

The documentation will be updated once the code is deemed reliable.

Thank you for your great reviews and tips! :)

sklearn/neural_network/mlp.py
+ -------
+ x_new: array-like, shape (M, N)
+ """
+ X *= -X

@GaelVaroquaux

GaelVaroquaux Jul 11, 2013

Member

You need to indicate in a very visible way in the docstring that the input data is modified.

@IssamLaradji

IssamLaradji Jul 11, 2013

Contributor

@GaelVaroquaux, so would it be better to write "Returns the value computed by the derivative of the hyperbolic tan function" instead of "Computes the derivative of the hyperbolic tan function" in line 100?

Or "Modifies the input 'x' via the computation of the tanh derivative"?

thanks

@larsmans

larsmans Jul 12, 2013

Member

If you make all of these private (prepend an _ to the name) then a single comment above them would be enough, I think.
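
Combining both suggestions, a sketch of what a private helper with an explicit in-place warning could look like (illustrative only, not the PR's final code):

```python
import numpy as np

def _dtanh(X):
    """Derivative of tanh, 1 - tanh(z)**2, evaluated on already-activated values.

    CAUTION: modifies X in place.
    """
    X *= -X          # X <- -tanh(z)**2
    X += 1           # X <- 1 - tanh(z)**2
    return X

A = np.tanh(np.array([[0.5, -1.0]]))
print(_dtanh(A))     # A itself now holds the derivative values
```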

@IssamLaradji

IssamLaradji Jul 12, 2013

Contributor

Thanks, fixed.
On another note, is there a way to get the Travis build to pass? Does it fail because the test cases are not established yet? Thanks

@larsmans

larsmans Jul 12, 2013

Member

Travis runs a bunch of tests on the entire package, including your code because it detects a class inheriting from ClassifierMixin. You should inspect the results to see what's going wrong.

@IssamLaradji

IssamLaradji Jul 12, 2013

Contributor

Thanks, the errors are very clear now :)

@GaelVaroquaux

Member

GaelVaroquaux commented Jul 11, 2013

The neural_network sub-package needs to be added to the setup.py of sklearn. Otherwise it does not get copied during the install.

sklearn/neural_network/mlp.py
+ }
+ self.activation = activation_functions[activation]
+ self.derivative = derivative_functions[activation]
+ self.output_func = activation_functions[output_func]

@luoq

luoq Jul 11, 2013

parameters of `__init__` should not be changed. Same for random_state below. See http://scikit-learn.org/stable/developers/index.html#apis-of-scikit-learn-objects

@IssamLaradji

IssamLaradji Jul 11, 2013

Contributor

@luoq thanks for pointing that out, I fixed the initialization issues and pushed the new code.

~Issam

@IssamLaradji

Contributor

IssamLaradji commented Aug 2, 2013

Hi everyone, sorry for being inactive on this, it's been a laborious 2 weeks :). I have updated the code by improving the documentation and eliminating tanh-related problems. Since tanh can yield negative values, applying the log function to them produces an error, so I added code that scales the value into the [0, 1] range to ensure it is positive, via 0.5*(a_output + 1).

The code seems to be accepted by the Travis build; however, MLPRegressor is yet to be implemented, but that will be done soon.

PS: I'm also writing a blog that aims at helping newcomers to the field (and maybe even practitioners) engage with neural networks:
http://easydeeplearning.blogspot.com/p/multi-layer-perceptron-tutorial.html

Thanks in advance!

@arjoly

Member

arjoly commented Aug 2, 2013

Instead of using the abbreviation MLP, why not write MultilayerPerceptron in full? Then you would get the more readable class names MultilayerPerceptronClassifier, MultilayerPerceptronRegressor and BaseMultilayerPerceptron.

@larsmans

Member

larsmans commented Aug 2, 2013

Also, can you rebase on master? We merged RBMs, so there's a neural_network module now.

sklearn/neural_network/mlp.py
+from itertools import cycle, izip
+
+
+def _logistic(x):

@larsmans

larsmans Aug 2, 2013

Member

In master, there's a fast logistic function in sklearn.utils.extmath.

@IssamLaradji

IssamLaradji Aug 2, 2013

Contributor

Thanks, I will plug it in!

sklearn/neural_network/mlp.py
+ cost = np.sum(diff**2)/ (2 * n_samples)
+ elif self.loss == 'log':
+ # To avoid math error, tanh values are re-scaled
+ if self.activation == 'tanh': a_output = 0.5*(a_output + 1)

@larsmans

larsmans Aug 2, 2013

Member

The activation function should be used at the hidden layer. The output activation function should be either linear or softmax.

@GaelVaroquaux

GaelVaroquaux Aug 2, 2013

Member

Also a trivial style remark: even though it is valid Python, we prefer that 'if ... :' is followed by a line break and a one-line block.

@IssamLaradji

IssamLaradji Aug 2, 2013

Contributor

@larsmans, thanks a million! That was my blind spot in MLP. One concern, though: I presume a linear output activation is used in the case of binary classification, since a score obtained using decision_function that is above 0 is rounded to class 1 and below zero is rounded to class -1, which means class 0 must be labeled as -1 and the other as 1. But if the output is negative or zero, how would the log in the cross-entropy work without causing a math error? The loss function involves Ti log(ai) + (1 - Ti) log(1 - ai), where ai is the output and Ti is the actual label; ai can be negative, rendering the function invalid. Should there be some kind of scaling as a workaround?

Thanks @GaelVaroquaux, I will update the code as per your suggestion.

@larsmans

larsmans Aug 3, 2013

Member

The cross-entropy loss should be computed on the softmax, not on the decision function. Effectively, an MLP is logistic regression with a hidden layer, so check any reference for LR (e.g. Bishop, p. 209, if you have a copy of that).

@IssamLaradji

IssamLaradji Aug 3, 2013

Contributor

Hi @larsmans, thanks again. I have the book. I also found 3 loss functions for binary classification on this site: http://fa.bianp.net/blog/tag/loss-function.html

They are:

  1. zero_one: zero_one_loss(y_true, y_pred)
  2. hinge: max(0, 1 - (y_pred * y_true))
  3. logistic_log: log(1 + exp(-y_pred * y_true)), or log, which is adjusted to cross-entropy when the problem at hand is multi-class

Edit: Just found out that sgd_fast in scikit implements all the required loss functions :) :)

sklearn/neural_network/mlp.py
+ n_classes: int
+ Number of target classes
+ """
+ return np.hstack((W1.ravel(), W2.ravel(),

@larsmans

larsmans Aug 2, 2013

Member

hstack makes a copy of all these (rather large) matrices. There must be a smarter way of packing and unpacking parameters.

@IssamLaradji

IssamLaradji Aug 2, 2013

Contributor

Maybe I could add a pre-initialized weight array as a parameter of the _pack method that takes in the raveled weights, so that hstack would not need to be used?

Would that be more efficient?
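
One possible direction, sketched with assumed shapes (illustrative only, not the PR's layout): allocate the flat parameter vector once and expose the weight matrices as reshaped views into it, so the optimizer and the backpropagation share memory and no hstack copy is needed.

```python
import numpy as np

n_features, n_hidden, n_classes = 64, 100, 10
sizes = [n_features * n_hidden, n_hidden * n_classes, n_hidden, n_classes]
offsets = np.cumsum([0] + sizes)
theta = np.zeros(offsets[-1])                 # flat vector seen by the optimizer

# Reshaped slices of a contiguous 1-D array are views, not copies.
W1 = theta[offsets[0]:offsets[1]].reshape(n_features, n_hidden)
W2 = theta[offsets[1]:offsets[2]].reshape(n_hidden, n_classes)
b1 = theta[offsets[2]:offsets[3]]
b2 = theta[offsets[3]:offsets[4]]

W1[:] = 0.1          # writing through the view updates theta in place
assert theta[0] == 0.1
```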

@IssamLaradji

Contributor

IssamLaradji commented Aug 2, 2013

@arjoly, the naming is a good idea thanks!
@larsmans sure I will have it rebased

@IssamLaradji

Contributor

IssamLaradji commented Aug 4, 2013

Okay, done :). I have fixed the binary classification; I'm getting a 100% score with logistic as well as tanh on a binary dataset generated from scikit-learn's digits dataset. It turns out that I had to apply the logistic function on the output layer regardless of the activation function in the hidden layer, and the loss function is
-np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

Gladly, it passed the Travis tests. What is left now is to re-use some of scikit-learn's Cython-based loss functions (and logistic) for improved speed and to implement MLP for regression.

In addition, the packing and unpacking methods are to be improved.
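
A small sketch of that loss with clipping added so that log(0) cannot occur when the logistic output saturates (illustrative only, not the PR's code):

```python
import numpy as np

def binary_log_loss(y_true, y_pred, eps=1e-10):
    # clip predicted probabilities away from 0 and 1 before taking logs
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -np.mean(y_true * np.log(y_pred) + (1 - y_true) * np.log(1 - y_pred))

print(binary_log_loss(np.array([1, 0, 1]), np.array([0.9, 0.2, 0.8])))
```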

examples/neural_network/mlp_example.py
+ random_state=1).fit(X, y)
+ print(
+ "training accuracy for %s-based %s: %f" %
+ (activation, algorithm, clf.score(X, y)))

@arjoly

arjoly Aug 4, 2013

Member

It is not clear what you want to show.
Have a look at the other examples. Usually, examples are used to show some particularities of the underlying estimator, and figures are re-used in the narrative documentation.

examples/neural_network/mlp_example.py
+ clf = MultilayerPerceptronClassifier(
+ algorithm=algorithm,
+ activation=activation,
+ random_state=1).fit(X, y)

@arjoly

arjoly Aug 4, 2013

Member

Can you use a cross validation strategy to assess the performance?
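
One way the example could follow this suggestion; class and parameter names are those used elsewhere in this PR and may change, and `sklearn.cross_validation` was the module name at the time:

```python
from sklearn.cross_validation import cross_val_score   # sklearn.model_selection in later versions
from sklearn.datasets import load_digits
from sklearn.neural_network import MultilayerPerceptronClassifier

digits = load_digits()
clf = MultilayerPerceptronClassifier(algorithm='l-bfgs', activation='logistic',
                                     random_state=1)
scores = cross_val_score(clf, digits.data, digits.target, cv=5)
print("mean cross-validated accuracy: %.3f" % scores.mean())
```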

sklearn/neural_network/mlp.py
+
+# Author: Issam Laradji <issam.laradji@gmail.com>
+# Credits to: Amueller and Larsmans codes on MLP
+

@arjoly

arjoly Aug 4, 2013

Member

Can you add the license?

# Licence: BSD 3 clause
@IssamLaradji

IssamLaradji Aug 5, 2013

Contributor

done :)

sklearn/neural_network/mlp.py
+ -------
+ x_new: array-like, shape (M, N)
+ """
+ return 1. / (1. + np.exp(np.clip(-x, -30, 30)))

@arjoly

arjoly Aug 4, 2013

Member

Can you add a note on the clipping?

@IssamLaradji

IssamLaradji Aug 5, 2013

Contributor

I removed the whole logistic method; I will be using a faster implementation instead: from sklearn.utils.extmath import logistic_sigmoid

sklearn/neural_network/mlp.py
+
+ Parameters
+ ----------
+ x: array-like, shape (M, N)

@arjoly

arjoly Aug 4, 2013

Member

To render properly in the documentation, a space is needed between "x" and ":".

Can you use more descriptive variable names for the shape?
For instance, n_samples, n_features, n_classes, ... instead of (M, N)

@arjoly

arjoly Aug 4, 2013

Member

For a proper rendering of the doc, numpy conventions are used
https://github.com/numpy/numpy/blob/master/doc/HOWTO_DOCUMENT.rst.txt#docstring-standard

@IssamLaradji

IssamLaradji Aug 6, 2013

Contributor

Thanks, that makes a lot more sense.

sklearn/neural_network/mlp.py
+ return (X.T - logsumexp(X, axis=1)).T
+
+
+def _softmax(X):

@arjoly

arjoly Aug 4, 2013

Member

Why don't you take the exponential of _log_softmax?
Which one is the more stable numerically?
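
For reference, the two routes being compared, written out with the usual max-shift / log-sum-exp tricks (an illustrative sketch, not the PR's code). Both avoid overflow; keeping the log-softmax around is mainly useful when the log-probabilities are needed directly for the cross-entropy loss:

```python
import numpy as np

def softmax_shifted(X):
    # subtract the row-wise max so np.exp never overflows
    Z = X - X.max(axis=1, keepdims=True)
    e = np.exp(Z)
    return e / e.sum(axis=1, keepdims=True)

def softmax_via_log(X):
    # log-softmax via the log-sum-exp trick, then exponentiate
    m = X.max(axis=1, keepdims=True)
    log_norm = m + np.log(np.exp(X - m).sum(axis=1, keepdims=True))
    return np.exp(X - log_norm)

X = np.array([[1.0, 2.0, 3.0], [1000.0, 1001.0, 1002.0]])
print(softmax_shifted(X))
print(softmax_via_log(X))
```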

sklearn/neural_network/mlp.py
+ Warning: This class should not be used directly.
+ Use derived classes instead.
+ """
+ """Multi-layer perceptron (feedforward neural network) classifier.

@arjoly

arjoly Aug 4, 2013

Member

This docstring doesn't seem to be in the right place.

@IssamLaradji

IssamLaradji Aug 5, 2013

Contributor

Thanks, I moved it to the Classifier instead of the base

sklearn/neural_network/__init__.py
"""
The :mod:`sklearn.neural_network` module includes models based on neural
networks.
"""
from .rbm import BernoulliRBM
+from .mlp import MultilayerPerceptronClassifier

@arjoly

arjoly Aug 4, 2013

Member

Can you add an __all__ variable?
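
For example, something along these lines, assuming only the two estimators shown in the diff above are exported:

```python
__all__ = ["BernoulliRBM", "MultilayerPerceptronClassifier"]
```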

@IssamLaradji

IssamLaradji Aug 5, 2013

Contributor

done

@arjoly

Member

arjoly commented Aug 4, 2013

Can you add tests?

@GaelVaroquaux

Member

GaelVaroquaux commented Aug 4, 2013

Minor remark: I would prefer if the file name wasn't 'mlp.py'. Acronyms are to be avoided as much as possible ('rbm.py' was already quite near the limit, but it is such a standard acronym that we went for it).

IssamLaradji added some commits Jun 30, 2013

A lot of improvements
1) Optimization_method parameter for selecting any scipy optimizer
2) SGD is also supported
3) Improved minibatch creation using scikit's gen_even_slices (much faster)
4) Support of loss functions
More fixes
Renames, separated the algorithm into `classifier` and `regression`
Added an example to illustrate MLP performance on the digits dataset, verbose for SGD and minibatch processing for l-bfgs

@IssamLaradji IssamLaradji referenced this pull request May 27, 2014

Closed

[MRG] Generic multi layer perceptron #3204

3 of 4 tasks complete
@IssamLaradji

Contributor

IssamLaradji commented May 27, 2014

Hi guys, I am closing this pull-request because of the very long discussion.
Here is the new pull request: #3204.

Thanks