[MRG+3] Add mean absolute error splitting criterion to DecisionTreeRegressor #6667

Merged · 75 commits · Jul 25, 2016
nelson-liu (Contributor) commented Apr 15, 2016

Adding a mean absolute error criterion to tree.DecisionTreeRegressor

Sorry for the long silence; the past few weeks have been busy with the start of a new quarter. Things have settled down, though, and I'm ready to resume contributing.
I spent the past few days reading through the tree module to get a handle on it, and I've begun implementing the mean absolute error (MAE) split criterion in DecisionTreeRegressor. I'm opening this WIP PR as a public ground for discussion of the code; getting feedback early and failing fast should maximize what I learn from this PR and can apply to future contributions.

Here's a task list of the sub-objectives I see:

  • override the node_value method to calculate the median
    • I have an initial version of node_value in my first commit; please let me know if I'm on the right track and whether there's anything to fix or improve in correctness, efficiency, or style
  • update node_impurity to return the mean absolute error
  • write the children_impurity method

I've never used C / C++ before, so I've been learning and experimenting with C and Cython as well. If you see a segment of code that looks incorrect, please point it out! I'm looking forward to learning more about Cython and C through this PR.

Thanks!
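For orientation, the criterion being implemented boils down to this: a node's value is the (weighted) median of the targets in the node, and its impurity is the mean absolute deviation of those targets from that median. A minimal pure-Python sketch of the unweighted case (illustration only, not the PR's Cython code):

import numpy as np

def mae_impurity(y):
    # A node's MAE impurity: mean absolute deviation around the median.
    med = np.median(y)
    return np.mean(np.abs(y - med))

y = np.array([1.0, 2.0, 6.0])                    # median = 2.0
assert np.isclose(mae_impurity(y), (1.0 + 0.0 + 4.0) / 3)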

sklearn/tree/_criterion.pyx
y_vals = NULL
weights = NULL
y_vals = <double*> calloc(self.n_node_samples, sizeof(double))

nelson-liu (Contributor) · Apr 17, 2016

Not sure if this is the correct way to instantiate some sort of collection object in Cython...

sklearn/tree/_criterion.pyx
cdef SIZE_t* samples = self.samples
cdef DOUBLE_t* y = self.y
cdef SIZE_t start = self.start

nelson-liu (Contributor) · Apr 17, 2016

I noticed that other code in this file commonly does something like this, setting a local variable to the value of a class attribute. Is there any reason not to use self.start directly here?

raghavrv (Member) · Jul 21, 2016

It's just for readability.

sklearn/tree/_criterion.pyx
with gil:
print "calculated weighted median:"
print y_val_pointer[median_index]
dest[k] = y_val_pointer[median_index]

nelson-liu (Contributor) · Apr 17, 2016

The code runs into a Bus Error: 10 at this line. Perhaps this is because my destination pointer, as called in node_impurity(), isn't a collection / array?

sklearn/tree/_criterion.pyx
with gil:
print "normally this isn't printed because of bus error: 10"
cdef double node_impurity(self) nogil:

nelson-liu (Contributor) · Apr 17, 2016

I couldn't come up with any clever way to calculate the mean absolute error, so node_impurity just manually sums up all the differences between y_i and the median.

nelson-liu (Contributor) commented Apr 17, 2016

So I wrote an initial version of node_impurity and fixed several issues in node_value as well. However, I'm getting a Bus Error: 10 on line 1019 of _criterion.pyx, when the value of the pointer is assigned. The print statement indicates that the weighted median is being calculated correctly; I'm just unable to store it in the dest pointer. Could anyone let me know what I'm doing incorrectly?

Also, I sprinkled the diff with some comments and general questions I had about Cython; it'd be great if someone could answer them.

nelson-liu (Contributor) commented Apr 18, 2016

I fixed the Bus Error: 10 and got node_value to correctly compute the weighted median when the data is sorted. I'm still thinking about the best way to generalize the function to unsorted data, as I'm sure there's a better solution than sorting and then running the current algorithm. Does anyone have any ideas?

Also, another currently broken-ish behavior is that the median of an even-length array comes out as the lower of the two middle values rather than their average, e.g. med([1, 2, 3, 4]) returns 2 instead of 2.5. Although this makes no difference to the overall MAE calculation, it still needs fixing for exporting the tree to graphviz. It's fairly easy to fix, though.

maniteja123 (Contributor) commented Apr 18, 2016

Hi, sorry for the noise, but for the purposes of median_absolute_error, as per the interpolation method on Wikipedia, there is an implementation for calculating the weighted median here. I'm not aware of the intricate details, but I suppose it could be coded in Cython along the same lines. Hope it helps in some way.
EDIT: Changed the link. Sorry about that!

nelson-liu (Contributor) commented Apr 18, 2016

Ah, I see: the code you linked sorts the array first and then applies an algorithm similar to mine. I was wondering whether the weighted median could be computed without sorting, but it seems reasonable that computing a median inherently involves sorting, since you're partitioning the data into two halves, one of which is less than the other. Thanks for the input @maniteja123!
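A minimal sketch of the sort-based weighted median under discussion (plain Python with numpy; the 50%-of-total-weight cutoff is one common convention, not necessarily the exact one the PR adopts):

import numpy as np

def weighted_median(values, weights):
    # Smallest value v such that the cumulative weight of elements
    # <= v reaches half of the total weight.
    order = np.argsort(values)
    v = np.asarray(values, dtype=float)[order]
    w = np.asarray(weights, dtype=float)[order]
    csum = np.cumsum(w)
    return v[np.searchsorted(csum, csum[-1] / 2.0)]

assert weighted_median([1, 2, 3], [1, 1, 1]) == 2
assert weighted_median([1, 2, 3], [1, 1, 5]) == 3  # heavy weight pulls the median up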

nelson-liu (Contributor) commented Apr 19, 2016

I implemented a generic quicksort to sort the array of y values together with their weights, and corrected a bug that miscalculated the median when the average of two bounding points has to be taken. node_value should be functionally correct now, but I'm sure there are ways to make it faster / more efficient.

My next steps are to verify the correctness of the MAE implementation in node_impurity and to write proxy_impurity and children_impurity.

nelson-liu (Contributor) commented Apr 19, 2016

So I ran a few tests where I used toy data and compared the MAE calculated by hand against the code's results. I'll write unit tests in a bit, but node_impurity seems to be functioning.
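A hand check of the kind described might look like the following (hypothetical toy data, not the actual tests):

import numpy as np

y = np.array([3.0, 1.0, 2.0, 10.0])
median = np.median(y)                    # 2.5
mae = np.mean(np.abs(y - median))        # (0.5 + 1.5 + 0.5 + 7.5) / 4
assert mae == 2.5                        # matches the by-hand value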

nelson-liu (Contributor) commented Apr 19, 2016

I just wrote an initial version of children_impurity, but it's untested and unrefined. I also noticed a lot of code overlap between node_value and children_impurity, so I'm considering splitting the median calculation into a separate function. However, the other criteria have a fair amount of overlapping code as well...

I'm also not quite sure how to implement proxy_impurity_improvement, as I can't yet think of a proxy that would work for MAE. Does anyone have input? Perhaps @rvraghav93 @jmschrei @glouppe?

sklearn/tree/_criterion.pyx
median_dest[k] = y_vals[median_index]
cdef void sort_values_and_weights(self, double* y_vals, double* weights,

raghavrv (Member) · Apr 20, 2016

Could you use DOUBLE_t everywhere?

nelson-liu (Contributor) · Apr 20, 2016

Sure, what are the advantages of DOUBLE_t over double?

raghavrv (Member) · Apr 20, 2016

double is the standard C double. Its size is almost always 64 bits, but that isn't guaranteed: it depends on the compiler and/or platform.

Here DOUBLE_t is basically ctypedef cnp.npy_float64 DOUBLE_t, which fixes the size to 64 bits to prevent unwanted side effects.

nelson-liu (Contributor) · Apr 20, 2016

Ah, OK, that makes sense; thanks for the explanation! Why is double used vs DOUBLE_t in other parts of the file, then? Should that be changed?

raghavrv (Member) · Apr 20, 2016

Hmm... leave those as they are. Whenever you make a change, use DOUBLE_t.

raghavrv (Member) · Apr 20, 2016

Ah wait! Use DOUBLE_t for y and weights, as they will be taken from a numpy array. For anything that won't be taken from, or converted back into, a numpy array, you are free to use double. That is what is done elsewhere. Sorry if my previous comment did not make that clear.

raghavrv (Member) · Apr 20, 2016

Can @glouppe or @jnothman confirm if I am right?

nelson-liu (Contributor) · Apr 20, 2016

Is there any reason not to just use DOUBLE_t for everything?

raghavrv (Member) · Apr 20, 2016

I think DOUBLE_t comes with an overhead.

nelson-liu (Contributor) · Apr 21, 2016

I've changed double / DOUBLE_t as appropriate; can you let me know if I missed anything?

sklearn/tree/_criterion.pyx
"""Evaluate the impurity of the current node, i.e. the impurity of
samples[start:end]"""
cdef double* medians = <double *> calloc(self.n_outputs, sizeof(double))
cdef double impurity = 0.0

raghavrv (Member) · Apr 20, 2016

Same comment (use DOUBLE_t everywhere)... though here it's okay to use double.

raghavrv (Member) commented Apr 21, 2016

A few comments: sort_values... and compute_medians... should be unbound helpers, probably placed in the utils...

nelson-liu (Contributor) commented Apr 21, 2016

@rvraghav93 OK, I'll change that. It'll probably require a slight reworking, given the inability to use the criterion's class attributes from utils, but that should be trivial, if a bit verbose.

raghavrv (Member) commented Apr 21, 2016

It should be something like how splitter.pyx uses the sort helper. For now just put it inside criterion.pyx; let's move it to utils later. I'll comment more specifically by this weekend.

nelson-liu (Contributor) commented May 12, 2016

Sorry for not working on this for a bit. I spent the past two weeks tinkering with the tree module code and Cython, and I feel I now have a much better grasp of both, which will help throughout this PR and future tree work.

I moved the sorting functions to _utils.pyx and also fixed an incorrect type cast between double* and DOUBLE_t*. I ran some quick examples with node_value and node_impurity, and I'm quite confident they correctly compute the median and the MAE, respectively. I'll revisit my implementation of children_impurity and write proxy_impurity_improvement shortly.

In the meanwhile, I'd appreciate it if someone could look over my node_value and node_impurity functions and let me know if there's anything to change; it's always nice to have a pair of fresh eyes.

sklearn/tree/_criterion.pyx
for p in range(start, pos):
i = samples[p]
y_ik = y[i * self.y_stride + k]
# impurity_left[0] += (fabs(y_ik - medians[k]) / (pos - start))

nelson-liu (Contributor) · May 21, 2016

I'm not sure why this works, but it does; if I do it with the statements that are commented out, there's a float rounding error.
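The rounding difference is plausibly just floating-point non-associativity: dividing each term before accumulating rounds at every step, while summing first and dividing once rounds only at the end. A quick generic illustration (not the PR's data):

# Floating-point addition is not associative, so regrouping the
# same terms can change the last bits of the result.
a = (0.1 + 0.2) + 0.3
b = 0.1 + (0.2 + 0.3)
print(a == b)   # False
print(a, b)     # 0.6000000000000001 vs 0.6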

ogrisel (Member) commented May 26, 2016

I am not familiar enough with the tree code to know whether a proxy impurity function for MAE is possible.

For the segfault on 32-bit Python under Windows reported by AppVeyor: if you don't have Windows at hand, or are too lazy to install it in a VM (you will also need VS 2015 to build scikit-learn with Python 3.5), you can try to reproduce the issue with Anaconda 32-bit for Linux (assuming you use Linux to develop on scikit-learn).

nelson-liu (Contributor) commented May 27, 2016

@ogrisel thanks for the tip! I installed the dependencies through pip on a fresh 32-bit Ubuntu VM, and it worked perfectly. Does Anaconda do something different, or are they functionally equivalent?

I'm going to try installing the tools needed to build scikit-learn on a Windows VM. Initially I had set up Cygwin / MinGW, but I don't think that's the way to do it, as that essentially emulates Linux; I'm currently setting up Python 2.7/3.5 without Cygwin, using plain PowerShell.

sklearn/tree/_criterion.pyx
return ((self.weighted_n_node_samples / self.weighted_n_samples) *
(impurity - (self.weighted_n_right /

raghavrv (Member) · May 27, 2016

Could you avoid modifying lines that are unrelated to this PR, please? It works in your favor by speeding up the review, as the diff is smaller...

sklearn/tree/_criterion.pyx
@@ -263,7 +263,7 @@ cdef class ClassificationCriterion(Criterion):
self.sum_left = <double*> calloc(n_elements, sizeof(double))
self.sum_right = <double*> calloc(n_elements, sizeof(double))
if (self.sum_total == NULL or

raghavrv (Member) · May 27, 2016

same comment

sklearn/tree/_criterion.pyx
@@ -722,7 +722,7 @@ cdef class RegressionCriterion(Criterion):
self.sum_left = <double*> calloc(n_outputs, sizeof(double))
self.sum_right = <double*> calloc(n_outputs, sizeof(double))
if (self.sum_total == NULL or

raghavrv (Member) · May 27, 2016

same comment

sklearn/tree/_criterion.pyx
@@ -957,11 +954,108 @@ cdef class MSE(RegressionCriterion):
for k in range(self.n_outputs):
impurity_left[0] -= (sum_left[k] / self.weighted_n_left) ** 2.0
impurity_right[0] -= (sum_right[k] / self.weighted_n_right) ** 2.0

raghavrv (Member) · May 27, 2016

here too

sklearn/tree/_criterion.pyx
@@ -921,7 +919,6 @@ cdef class MSE(RegressionCriterion):
left child (samples[start:pos]) and the impurity the right child
(samples[pos:end])."""

raghavrv (Member) · May 27, 2016

this
sklearn/tree/_criterion.pyx
@@ -886,7 +885,6 @@ cdef class MSE(RegressionCriterion):
impurity = self.sq_sum_total / self.weighted_n_node_samples
for k in range(self.n_outputs):
impurity -= (sum_total[k] / self.weighted_n_node_samples)**2.0

raghavrv (Member) · May 27, 2016

this
sklearn/tree/_criterion.pyx
@@ -847,7 +846,7 @@ cdef class RegressionCriterion(Criterion):
self.weighted_n_left -= w
self.weighted_n_right = (self.weighted_n_node_samples -

raghavrv (Member) · May 27, 2016

and this ;)

sklearn/tree/_criterion.pyx
"""Evaluate the impurity of the current node, i.e. the impurity of
samples[start:end]"""
cdef double* medians = NULL
medians = <double*> calloc(self.n_outputs, sizeof(double))

raghavrv (Member) · May 27, 2016

Shouldn't you use safe_realloc here?

nelson-liu (Contributor) · May 28, 2016

Hmm, why do you suggest that? I'm not quite sure what the use case for safe_realloc is, so I'm wondering.

glouppe (Member) · Jun 6, 2016

always use the safe alloc

nelson-liu (Contributor) · Jun 18, 2016

Hmm, I can't seem to use safe_realloc here; Cython always gives me a "no suitable method found" error. Could it be because safe_realloc requires the GIL and thus can't be used in a nogil function?

glouppe (Member) · Jun 19, 2016

You could allocate medians once at the beginning of the construction of the tree, and then free it afterwards. This would make it possible to use safe_realloc while avoiding the allocation/deallocation dance you are currently doing.
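The suggested pattern, allocate once at construction, reuse across calls, free at teardown, looks roughly like this in Python terms (a numpy analogy of the Cython buffer handling, not the actual implementation):

import numpy as np

class MAECriterionSketch(object):
    # Illustrative only: one scratch buffer reused for every median
    # computation, instead of calloc/free inside node_impurity.

    def __init__(self, n_outputs):
        # Allocated once up front (the Cython code would safe_realloc
        # a raw buffer here, while the GIL is still held).
        self.medians = np.empty(n_outputs)

    def node_impurity(self, y):
        # y has shape (n_samples, n_outputs); the buffer is reused.
        np.median(y, axis=0, out=self.medians)
        return np.mean(np.abs(y - self.medians))

crit = MAECriterionSketch(n_outputs=2)
y = np.array([[1.0, 0.0], [2.0, 2.0], [6.0, 4.0]])
print(crit.node_impurity(y))   # 1.5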

sklearn/tree/_splitter.pyx
@@ -1563,6 +1567,7 @@ cdef class RandomSparseSplitter(BaseSparseSplitter):
# Evaluate split
self.criterion.reset()

raghavrv (Member) · May 27, 2016

Please avoid these changes. A bare-essentials diff is best practice :)

nelson-liu (Contributor) commented May 28, 2016

I've completed an initial working version of the MAE criterion; it passes the unit tests, which is nice. I'd love it if anyone could spare some time to review what I have.

I'll proceed to write some tests for it and run some benchmarks. MAE is likely to be a lot slower than MSE because it doesn't have a proxy impurity function yet.

I've been talking to @rvraghav93 and @jmschrei, and we aren't sure whether it's possible to create a proxy impurity function for a median calculation, like the one in MSE. Does anyone have any input?

nelson-liu (Contributor) commented May 29, 2016

I'm not sure there are any tests I should be writing specifically for MAE: it's currently passing the tests that all Criterion are normally subjected to. I made a quick demo script and recorded the results to demonstrate that the performance of MAE and MSE is comparable. However, MAE is far slower due to the lack of a proxy function; any tips on how to improve its speed?
Here's the testing script used: https://gist.github.com/nelson-liu/f52005bb074ddf786e4332a20b8aad7a
The results are as follows:

Training and testing on datasets.boston

full output here: https://gist.github.com/nelson-liu/ac2a041faeca9941331e205793eaf1c0

MSE time: 105 function calls in 0.004 seconds
MAE time: 105 function calls in 0.175 seconds

Mean Squared Error of Tree Trained w/ MSE Criterion: 32.257480315
Mean Squared Error of Tree Trained w/ MAE Criterion: 29.117480315

Mean Absolute Error of Tree Trained w/ MSE Criterion: 3.50551181102
Mean Absolute Error of Tree Trained w/ MAE Criterion: 3.36220472441

Training and testing on randomly generated data from datasets.samples_generator.make_regression

full output here: https://gist.github.com/nelson-liu/6633b5a723debae59d647476764d8bd8

MSE time: 105 function calls in 0.089 seconds
MAE time: 105 function calls in 15.419 seconds

Mean Squared Error of Tree Trained w/ MSE Criterion: 0.702881265958
Mean Squared Error of Tree Trained w/ MAE Criterion: 0.66665916831

Mean Absolute Error of Tree Trained w/ MSE Criterion: 0.650976429446
Mean Absolute Error of Tree Trained w/ MAE Criterion: 0.657671579992
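For anyone wanting to reproduce a comparison like this, the new criterion is selected by name; a minimal timing sketch (written against the present-day API and assuming the Boston loader available at the time, not the gist's actual script):

import time
from sklearn.datasets import load_boston
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeRegressor

X, y = load_boston(return_X_y=True)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, random_state=0)

for criterion in ("mse", "mae"):
    tree = DecisionTreeRegressor(criterion=criterion, random_state=0)
    start = time.time()
    tree.fit(X_tr, y_tr)
    print(criterion, "%.3fs" % (time.time() - start),
          mean_absolute_error(y_te, tree.predict(X_te)))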
raghavrv (Member) commented Jun 8, 2016

Thanks for the benchmarks.

Could you try this criterion with ExtraTreesRegressor (the ensemble version) on some regression dataset and report the training time alone in comparison with MSE, please? That should be much faster than using the best splitter (and should yield comparable results).

So we now need to speed this up and find a way to compute the MAE of either partition efficiently, without sorting at each threshold, correct?

I think, as Jacob suggested in the hangout, you need a data structure to speed up the median computation...

Could you look into a min-max-median heap?
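For reference, the classic structure for maintaining a median under insertions is a pair of heaps, a max-heap for the lower half and a min-heap for the upper; a minimal unweighted sketch of the idea (the criterion ultimately needs a weighted variant that also supports removals as the split position advances):

import heapq

class RunningMedian(object):
    # Two-heap running median: lo is a max-heap (negated values)
    # for the lower half, hi is a min-heap for the upper half.

    def __init__(self):
        self.lo, self.hi = [], []

    def push(self, x):
        if self.lo and x > -self.lo[0]:
            heapq.heappush(self.hi, x)
        else:
            heapq.heappush(self.lo, -x)
        # Rebalance so the halves differ in size by at most one.
        if len(self.lo) > len(self.hi) + 1:
            heapq.heappush(self.hi, -heapq.heappop(self.lo))
        elif len(self.hi) > len(self.lo):
            heapq.heappush(self.lo, -heapq.heappop(self.hi))

    def median(self):
        if len(self.lo) > len(self.hi):
            return -self.lo[0]
        return (-self.lo[0] + self.hi[0]) / 2.0

rm = RunningMedian()
for x in [5, 1, 3, 2, 4]:
    rm.push(x)
print(rm.median())   # 3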

raghavrv (Member) commented Jun 8, 2016

And I meant that your node_impurity function will be called more than once.

You could declare the median array at __cinit__, store the pointer to the block of memory, and reuse the same block for all the median computations (as long as the size of the median array need not change, i.e. across multiple calls to node_impurity within the same split).

If node_impurity is called for a different-sized split, you just call safe_realloc to resize this block of memory to the new size.

This would be (1) cleaner and (2) safer than allocating and releasing the memory inside node_impurity.

raghavrv (Member) commented Jun 8, 2016

As for the AppVeyor failure on the 32-bit build: are you assuming any array is 64 bits wide and accessing it under that assumption?

sklearn/tree/_criterion.pyx
samples[start:end]"""
cdef double* medians = NULL
medians = <double*> calloc(self.n_outputs, sizeof(double))
if (medians == NULL):

raghavrv (Member) · Jun 8, 2016

You can remove the brackets

nelson-liu (Contributor) · Jun 18, 2016

Good catch, thanks. Java habits spreading...

nelson-liu added some commits Apr 15, 2016

testing code for node_impurity and node_value
This code runs into 'Bus Error: 10' at node_value final assignment.

fix: node_value now correctly calculating weighted median for sorted data.
Still need to change the code to work with unsorted data.
nelson-liu (Contributor) commented Jul 21, 2016

In the latest commit, I added an n_samples argument to the __cinit__ of the regression criterion (although it is only used in MAE). This lets us build the WeightedMedianCalculators directly while we still hold the GIL, so we don't have to reacquire it later. As always, I reran the benchmarks; they're below.

The scripts I used to benchmark can be found here: https://github.com/nelson-liu/sklearn_dev_scripts/tree/master/mae_mse_benchmarks_6667

Training and testing (with .25 held-out data) on datasets.boston

full output here: https://github.com/nelson-liu/sklearn_dev_scripts/blob/master/mae_mse_benchmarks_6667/results/nogil_alloc_once/benchmark_mae_mse_boston_nogil_alloc_once_results.txt

MSE time: 104 function calls in 0.003 seconds
MAE time: 104 function calls in 0.032 seconds (was 0.038 seconds in the previous benchmark)
accuracies remain the same

Training and testing (0.25 held out) on randomly generated data from datasets.samples_generator.make_regression (int(1e3) samples with int(1e2) features)

full output here: https://github.com/nelson-liu/sklearn_dev_scripts/blob/master/mae_mse_benchmarks_6667/results/nogil_alloc_once/benchmark_mae_mse_generated_nogil_alloc_once_results.txt

MSE time: 104 function calls in 0.065 seconds
MAE time: 104 function calls in 0.961 seconds (was 0.978 seconds in the previous benchmark)
accuracies once again remain the same

So the gains from this change are incremental, but they're improvements nonetheless :)

raghavrv (Member) commented Jul 21, 2016

@nelson-liu This is good work. I think we can live with 10x slower than MSE; the primary reason MSE is faster is its clever proxy_impurity function...

Also, have you benchmarked MSE vs MAE for ExtraTreesRegressor?

jmschrei (Member) commented Jul 21, 2016

@raghavrv brings up a good point. I think this is pretty much ready, but let's run all the models that support MSE and MAE, so we have a record of the speed and accuracy differences this PR introduces.

@nelson-liu does that sound doable to you? This should include decision trees, random forests, extra trees, and gradient boosting.

nelson-liu (Contributor) commented Jul 21, 2016

Here are the requested benchmarks for MAE vs MSE on various models.

The scripts I used to benchmark can be found here: https://github.com/nelson-liu/sklearn_dev_scripts/tree/master/mae_mse_benchmarks_6667/bench_all

All models were run with default parameters except for random_state and criterion.

Training and testing (with .25 held-out data) on datasets.boston

full output here: https://github.com/nelson-liu/sklearn_dev_scripts/blob/master/mae_mse_benchmarks_6667/results/bench_all/benchmark_mae_mse_boston_all_results.txt

DecisionTreeRegressor:
MSE time: 104 function calls in 0.004 seconds
MAE time: 104 function calls in 0.032 seconds
Mean Squared Error of DecisionTreeRegressor Trained w/ MSE Criterion: 32.257480315
Mean Squared Error of DecisionTreeRegressor Trained w/ MAE Criterion: 29.117480315
Mean Absolute Error of DecisionTreeRegressor Trained w/ MSE Criterion: 3.50551181102
Mean Absolute Error of DecisionTreeRegressor Trained w/ MAE Criterion: 3.36220472441

RandomForestRegressor:
MSE time: 23068 function calls (22928 primitive calls) in 0.048 seconds
MAE time: 23068 function calls (22928 primitive calls) in 0.213 seconds
Mean Squared Error of RandomForestRegressor Trained w/ MSE Criterion: 19.7881511811
Mean Squared Error of RandomForestRegressor Trained w/ MAE Criterion: 17.354603937
Mean Absolute Error of RandomForestRegressor Trained w/ MSE Criterion: 2.62362204724
Mean Absolute Error of RandomForestRegressor Trained w/ MAE Criterion: 2.65125984252

GradientBoostingRegressor:
FriedmanMSE time: 7482 function calls in 0.046 seconds
MAE time: 7461 function calls in 1.485 seconds
Mean Squared Error of GradientBoostingRegressor Trained w/ Friedman MSE Criterion: 15.0091045221
Mean Squared Error of GradientBoostingRegressor Trained w/ MAE Criterion: 17.4920902086
Mean Absolute Error of GradientBoostingRegressor Trained w/ Friedman MSE Criterion: 2.5279874482
Mean Absolute Error of GradientBoostingRegressor Trained w/ MAE Criterion: 2.57534507477

ExtraTreesRegressor:
MSE time: 21348 function calls (21218 primitive calls) in 0.033 seconds
MAE time: 21348 function calls (21218 primitive calls) in 0.204 seconds
Mean Squared Error of ExtraTreesRegressor Trained w/ MSE Criterion: 22.7431417323
Mean Squared Error of ExtraTreesRegressor Trained w/ MAE Criterion: 18.5299811024
Mean Absolute Error of ExtraTreesRegressor Trained w/ MSE Criterion: 2.70251968504
Mean Absolute Error of ExtraTreesRegressor Trained w/ MAE Criterion: 2.58818897638

Training and testing (0.25 held out) on randomly generated data from datasets.samples_generator.make_regression (int(1e3) samples with int(1e2) features)

full output here: https://github.com/nelson-liu/sklearn_dev_scripts/blob/master/mae_mse_benchmarks_6667/results/bench_all/benchmark_mae_mse_generated_all_results.txt

DecisionTreeRegressor:
MSE time: 104 function calls in 0.081 seconds
MAE time: 104 function calls in 1.058 seconds
Mean Squared Error of DecisionTreeRegressor Trained w/ MSE Criterion: 0.702881265958
Mean Squared Error of DecisionTreeRegressor Trained w/ MAE Criterion: 0.66665916831
Mean Absolute Error of DecisionTreeRegressor Trained w/ MSE Criterion: 0.650976429446
Mean Absolute Error of DecisionTreeRegressor Trained w/ MAE Criterion: 0.657671579992

RandomForestRegressor:
MSE time: 23068 function calls (22928 primitive calls) in 0.488 seconds
MAE time: 23068 function calls (22928 primitive calls) in 5.368 seconds
Mean Squared Error of RandomForestRegressor Trained w/ MSE Criterion: 0.373731811322
Mean Squared Error of RandomForestRegressor Trained w/ MAE Criterion: 0.420901485605
Mean Absolute Error of RandomForestRegressor Trained w/ MSE Criterion: 0.512413946409
Mean Absolute Error of RandomForestRegressor Trained w/ MAE Criterion: 0.515708186449

GradientBoostingRegressor:
FriedmanMSE time: 7482 function calls in 2.2126 seconds
MAE time: 7461 function calls in 73.409 seconds
Mean Squared Error of GradientBoostingRegressor Trained w/ Friedman MSE Criterion: 0.101990638097
Mean Squared Error of GradientBoostingRegressor Trained w/ MAE Criterion: 0.116103550478
Mean Absolute Error of GradientBoostingRegressor Trained w/ Friedman MSE Criterion: 0.252211765258
Mean Absolute Error of GradientBoostingRegressor Trained w/ MAE Criterion: 0.266575379065

ExtraTreesRegressor:
MSE time: 21348 function calls (21218 primitive calls) in 0.215 seconds
MAE time: 21348 function calls (21218 primitive calls) in 3.934 seconds
Mean Squared Error of ExtraTreesRegressor Trained w/ MSE Criterion: 0.284074991003
Mean Squared Error of ExtraTreesRegressor Trained w/ MAE Criterion: 0.378661973095
Mean Absolute Error of ExtraTreesRegressor Trained w/ MSE Criterion: 0.446840085052
Mean Absolute Error of ExtraTreesRegressor Trained w/ MAE Criterion: 0.495304762963
glouppe (Member) commented Jul 22, 2016

I find it strange that MSE is often lower when minimizing MAE than when minimizing MSE directly. What do you observe for the error on the training data? (On training data, MSE should be lower when minimizing MSE, and similarly for MAE.)

jmschrei (Member) commented Jul 22, 2016

I can't provide a theoretical proof, but it's been a common finding for me that minimizing MAE can produce better MSE results in practice.
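On training data the expectation is well-founded: the mean of a sample minimizes the sum of squared deviations, and the median minimizes the sum of absolute deviations. A quick numeric check of both facts over a grid of candidate predictions:

import numpy as np

y = np.array([1.0, 2.0, 3.0, 10.0, 20.0])    # mean 7.2, median 3.0
candidates = np.linspace(0.0, 21.0, 8401)    # grid with step 0.0025

sq_best = candidates[np.argmin([np.sum((y - c) ** 2) for c in candidates])]
ab_best = candidates[np.argmin([np.sum(np.abs(y - c)) for c in candidates])]

print(sq_best, np.mean(y))     # squared loss is minimized at the mean
print(ab_best, np.median(y))   # absolute loss is minimized at the median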

nelson-liu (Contributor) commented Jul 22, 2016

@glouppe I ran some tests here, but I don't have the time right now to pick out the relevant bits of information in this comment.

Boston: https://github.com/nelson-liu/sklearn_dev_scripts/blob/master/mae_mse_benchmarks_6667/results/bench_all/benchmark_mae_mse_boston_all_results_train.txt
Generated: https://github.com/nelson-liu/sklearn_dev_scripts/blob/master/mae_mse_benchmarks_6667/results/bench_all/benchmark_mae_mse_generated_all_results_train.txt

Interestingly, it's not always the case that MSE is lower under the MSE criterion and MAE lower under the MAE criterion, even on training data. MSE simply seems to do better than MAE (on both datasets) with random forests and gradient boosting.

glouppe (Member) commented Jul 22, 2016

Thanks!

In any case, +1 for merge on my side. Thanks a lot for this useful contribution :)

nelson-liu changed the title from [MRG] Add mean absolute error splitting criterion to DecisionTreeRegressor to [MRG+1] Add mean absolute error splitting criterion to DecisionTreeRegressor on Jul 22, 2016

raghavrv (Member) commented Jul 22, 2016

Yohoo, you get a +1!! Amazing job!

jmschrei (Member) commented Jul 22, 2016

This looks solid and well documented/tested, with some minor improvements to the gradient boosting API as well. Let's get this GSoC PR merged! I give my +1 as well.

nelson-liu changed the title from [MRG+1] Add mean absolute error splitting criterion to DecisionTreeRegressor to [MRG+3] Add mean absolute error splitting criterion to DecisionTreeRegressor on Jul 22, 2016

glouppe merged commit c84ff5e into scikit-learn:master on Jul 25, 2016

4 checks passed

ci/circleci: Your tests passed on CircleCI!
continuous-integration/appveyor/pr: AppVeyor build succeeded
continuous-integration/travis-ci/pr: The Travis CI build passed
coverage/coveralls: Coverage increased (+0.006%) to 94.488%
glouppe (Member) commented Jul 25, 2016

Merging this! 🍻

raghavrv (Member) commented Jul 25, 2016

🍻

GaelVaroquaux (Member) commented Jul 25, 2016

nelson-liu (Contributor) commented Jul 25, 2016

Thanks all! 🎆 :shipit:

amueller (Member) commented Jul 25, 2016

🍻

@@ -1296,6 +1297,14 @@ class GradientBoostingClassifier(BaseGradientBoosting, ClassifierMixin):
of the input variables.
Ignored if ``max_leaf_nodes`` is not None.
criterion : string, optional (default="friedman_mse")

amueller (Member) · Jul 28, 2016

Sorry for being late to the party, but this should have a versionadded, right?

nelson-liu (Contributor) · Jul 28, 2016

Yes, it should; I'll add that.

amueller (Member) · Jul 28, 2016

thanks :)

@@ -1643,6 +1652,14 @@ class GradientBoostingRegressor(BaseGradientBoosting, RegressorMixin):
of the input variables.
Ignored if ``max_leaf_nodes`` is not None.
criterion : string, optional (default="friedman_mse")

amueller Jul 28, 2016

versionadded

The function to measure the quality of a split. The only supported
criterion is "mse" for the mean squared error, which is equal to
variance reduction as feature selection criterion.
The function to measure the quality of a split. Supported criteria

amueller Jul 28, 2016

I think this should have versionadded for mae

@@ -947,8 +947,10 @@ class RandomForestRegressor(ForestRegressor):
The number of trees in the forest.
criterion : string, optional (default="mse")
The function to measure the quality of a split. The only supported
criterion is "mse" for the mean squared error.
The function to measure the quality of a split. Supported criteria

amueller Jul 28, 2016

I think this should have versionadded for mae

@@ -1299,8 +1301,10 @@ class ExtraTreesRegressor(ForestRegressor):
The number of trees in the forest.
criterion : string, optional (default="mse")
The function to measure the quality of a split. The only supported
criterion is "mse" for the mean squared error.
The function to measure the quality of a split. Supported criteria

amueller Jul 28, 2016

I think this should have versionadded for mae
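
As context for the requests above, this is roughly the numpydoc fix being asked for; the ``.. versionadded:: 0.18`` version string is an assumption based on the merge date, not something stated in this thread:

criterion : string, optional (default="mse")
    The function to measure the quality of a split. Supported criteria
    are "mse" for the mean squared error, which is equal to variance
    reduction as feature selection criterion, and "mae" for the mean
    absolute error.

    .. versionadded:: 0.18
       Mean Absolute Error (MAE) criterion.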

olologin added a commit to olologin/scikit-learn that referenced this pull request Aug 24, 2016

[MRG+3] Add mean absolute error splitting criterion to DecisionTreeRegressor (#6667)

* feature: add initial node_value method

* testing code for node_impurity and node_value

This code runs into 'Bus Error: 10' at node_value final assignment.

* fix: node_value now correctly calculating weighted median for sorted data.

Still need to change the code to work with unsorted data.

* fix: node_value now correctly calculates median regardless of initial order

* fix: correct bug in calculating median when taking midpoint is necessary

* feature: add initial version of children_impurity

* feature: refactor median calculation into one function

* fix: fix use of DOUBLE_t vs double

* feature: move helper functions to _utils.pyx, fix mismatched pointer type

* fix: fix some bugs in children_impurity method

* push a debug version to try to solve segfault

* push latest changes, segfault probably happening bc of something in _utils.pyx

* fix: fix segfault in median calculation and remove excessive logging

* chore: revert some misc spacing changes I accidentally made

* chore: one last spacing fix in _splitter.pyx

* feature: don't calculate weighted median if no weights are passed in

* remove extraneous logging statement

* fix: fix children impurity calculation

* fix: fix bug with children impurity not being initally set to 0

* fix: hacky fix for a float accuracy error

* fix: incorrect type cast in median array generation for node_impurity

* slightly tweak node_impurity function

* fix: be more explicit with casts

* feature: revert cosmetic changes and free temporary arrays

* fix: only free weight array in median calcuation if it was created

* style: remove extraneous newline / trigger CI build

* style: remove extraneous 0 from range

* feature: save sorts within a node to speed it up

* fix: move parts of dealloc to regression criterion

* chore: add comment to splitter to try to force recythonizing

* chore: add comment to _tree.pyx to try to force recythonizing

* chore: add empty comment to gradient boosting to force recythonizing

* fix: fix bug in weighted median

* try moving sorted values to a class variable

* feature: refactor criterion to sort once initially, then draw all samples from this sorted data

* style: remove extraneous parens from if condition

* implement median-heap method for calculating impurity

* style: remove extra line

* style: fix inadvertent cosmetic changes; i'll address some of these in a separate PR

* feature: change minmaxheap to internally use sorted arrays

* refactored MAE and push to share work

* fix errors wrt median insertion case

* spurious comment to force recythonization

* general code cleanup

* fix typo in _tree.pyx

* removed some extraneous comments

* [ci skip] remove earlier microchanges

* [ci skip] remove change to priorityheap

* [ci skip] fix indentation

* [ci skip] fix class-specific issues with heaps

* [ci skip] restore a newline

* [ci skip] remove microchange to refactor later

* reword a comment

* remove heapify methods from queue class

* doc: update docstrings for dt, rf, and et regressors

* doc: revert incorrect spacing to shorten diff

* convert get_median to return value directly

* [ci skip] remove accidental whitespace

* remove extraneous unpacking of values

* style: misc changes to identifiers

* add docstrings and more informative variable identifiers

* [ci skip] add trivial comments to recythonize

* remove trivial comments for recythonizing

* force recythonization for real this time

* remove trivial comments for recythonization

* rfc: harmonize arg. names and remove unnecessary checks

* convert allocations to safe_realloc

* fix bug in weighted case and add tests for MAE

* change all medians to DOUBLE_t

* add loginc allocate mediancalculators once, and reset otherwise

* misc style fixes

* modify cinit of regressioncriterion to take n_samples

* add MAE formula and force rebuild bc. travis was down

* add criterion parameter to gradient boosting and add forest tests

* add entries to what's new
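
Several commits above mention a "median-heap" method for the impurity calculation. As background only, here is a minimal Python sketch of that general technique (a running median maintained with two heaps); the class and method names are invented for illustration, and the PR's actual Cython code also supports sample weights and removing samples:

import heapq

class RunningMedian:
    """Track the median of a stream with two heaps: a max-heap `lo`
    (stored as negated values) for the lower half and a min-heap `hi`
    for the upper half, kept within one element of each other."""

    def __init__(self):
        self.lo = []  # max-heap via negation
        self.hi = []  # min-heap

    def push(self, x):
        if self.lo and x > -self.lo[0]:
            heapq.heappush(self.hi, x)
        else:
            heapq.heappush(self.lo, -x)
        # Rebalance so that len(lo) == len(hi) or len(lo) == len(hi) + 1.
        if len(self.lo) > len(self.hi) + 1:
            heapq.heappush(self.hi, -heapq.heappop(self.lo))
        elif len(self.hi) > len(self.lo):
            heapq.heappush(self.lo, -heapq.heappop(self.hi))

    def median(self):
        if len(self.lo) > len(self.hi):
            return -self.lo[0]
        return (-self.lo[0] + self.hi[0]) / 2.0

The reason such a structure matters here: during split search, each candidate threshold moves a single sample from one child to the other, so a structure that updates the median in O(log n) keeps the scan cheap, whereas recomputing each child's median from scratch at every threshold would make the scan quadratic.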

TomDLT added a commit to TomDLT/scikit-learn that referenced this pull request Oct 3, 2016

[MRG+3] Add mean absolute error splitting criterion to DecisionTreeRegressor (#6667)

m0uH commented Dec 25, 2016

Hey guys, I hope this is the right place for my concern.

Since I need the MAPE (mean absolute percentage error) and am willing to implement it, I need some clarification on the implementation of the MAE criterion. Am I correct that the MAE is implemented using the median, and therefore should be called median_absolute_error instead of mean_absolute_error? There are also two different classes for mean_abs... and median_abs... in the sklearn.metrics module.

Also, I think this class calculates 1/n * (\sum_i |y_i - median(y)|); shouldn't it be median(|y_i - median(y)|)?

Perhaps I should then focus on a new class based on RegressionCriterion for implementing MAPE, (1/n) * (\sum_i |y_i - f_i| / |f_i|)? Or is this MAE just implemented with f_i = median(y)? Am I correct that for the MSE, f_i equals mean(y) in 1/n * (\sum_i (y_i - f_i)^2)?

I hope I didn't misunderstand something here.

JohnStott commented Mar 2, 2017

I have the same question as m0uH: I don't understand why the median comes into the equation when calculating "mean absolute error". I thought it should be implemented much the same as MSE, but replacing the power of 2 with an absolute value...? I would really appreciate it if somebody could explain why the median is used.

jmschrei commented Mar 2, 2017

The MAE of an array is minimized by the median of the array. This means that if you have an array and want to choose a single value which minimizes the MAE between that value and the array, it is always the median. In contrast, the MSE of an array is minimized by the mean of the array.
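
In symbols, the statement above is (standard notation, not anything specific to this PR):

\operatorname{median}(y) \in \arg\min_{c \in \mathbb{R}} \frac{1}{n} \sum_{i=1}^{n} \lvert y_i - c \rvert,
\qquad
\operatorname{mean}(y) = \arg\min_{c \in \mathbb{R}} \frac{1}{n} \sum_{i=1}^{n} (y_i - c)^2

(The "\in" on the left is because the L1 minimizer need not be unique when n is even.)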

jnothman commented Mar 3, 2017

Great explanation. Can we add it to the docs?

JohnStott commented Mar 3, 2017

Thank you jmschrei, this is difficult for me to understand, but I think I now get it with your help. Please could you confirm the following is correct:

If I get the mean absolute error and the median absolute error of the following array:
1, 2, 3, 5, 9, 10, 20
I get:
mean absolute error = 0
median absolute error = 15

We can't use the mean-based absolute error because it will always be 0, i.e., splits would all end up with a mean absolute error of 0, and the absolute error across the parent would be 0 too, thus we wouldn't end up with any splits! Is this correct?

jmschrei commented Mar 3, 2017

Consider the following:

import numpy

def mse(x, y):
    # mean squared error of array x around the scalar y
    return ((x - y) ** 2).mean()

def mae(x, y):
    # mean absolute error of array x around the scalar y
    return numpy.abs(x - y).mean()

a = numpy.array([1, 2, 5, 10, 20])

a_mean = a.mean()
a_median = numpy.median(a)

print(mse(a, a_mean), mae(a, a_mean))      # 48.24 5.92
print(mse(a, a_median), mae(a, a_median))  # 55.0 5.4

I am not sure how you got a mean absolute error of 0. In this example you can see that the mean gives a smaller MSE than the median, but that the median gives a smaller MAE than the mean. I hope that helps.

JohnStott commented Mar 3, 2017

My apologies, in my rush I had forgotten the most important step: taking the absolute value... :-s

Thank you for your example.

I have tried numerous example arrays through your mae function, and it seems that the median does in fact always produce the lower error... I assume, therefore, that the absolute error will always be smaller when using the median rather than the mean?

I am probably being thick (like earlier, haha), but if my above assumption is true, why does this make the median a better candidate than the mean when calculating absolute error? Hopefully I have got it right this time by saying: it is because a smaller error means it has fit the data better!?

(I am wondering about the intricacies because calculating a "mean"-based absolute error would be so much faster to run - but pointless if far inferior?... and/or whether "_criterion.pyx -> proxy_impurity_improvement" could be loosely based around this?)

nelson-liu commented Mar 3, 2017

> I assume therefore that the error will always be smaller for the absolute error when using the median rather than the mean?

Not @jmschrei, but yes, this is correct (well, technically they could be equal). Hence why we use it :)

> Hopefully, I have got it right this time by saying..., it is because a smaller error means it has fit the data better!?

Indeed, the goal of the tree is to grow in a way that minimizes the criterion (in this case, MAE) on the train set.

> I am wondering about the intricacies because calculating a "mean" based absolute error would be so much faster to run - but pointless if far inferior

Yes, calculating MAE with the "mean" does not actually minimize the MAE. Since our goal is to minimize the MAE, we use the median over the mean. In other words, a "mean"-based mean absolute error would be much faster to compute, but would not guarantee good results, because the mean does not minimize the criterion.
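
To make this concrete, here is a brute-force NumPy sketch of what an MAE-based split search evaluates; the function names are invented for illustration, and the merged Cython criterion computes the same quantities incrementally (and with sample weights) rather than recomputing medians at every threshold:

import numpy as np

def node_mae(y):
    # MAE impurity of a node: mean |y_i - median(y)|; the median is
    # also the value the node would predict.
    return np.abs(y - np.median(y)).mean()

def best_mae_split(x, y):
    # Scan all thresholds on one feature and return the one that
    # minimizes the weighted MAE impurity of the two children.
    order = np.argsort(x)
    x, y = x[order], y[order]
    n = len(y)
    best_score, best_threshold = np.inf, None
    for i in range(1, n):
        if x[i] == x[i - 1]:
            continue  # equal feature values admit no threshold here
        score = (i * node_mae(y[:i]) + (n - i) * node_mae(y[i:])) / n
        if score < best_score:
            best_score, best_threshold = score, (x[i - 1] + x[i]) / 2.0
    return best_threshold, best_score

x = np.array([0.5, 1.0, 2.0, 3.5, 4.0])
y = np.array([1.0, 2.0, 5.0, 10.0, 20.0])
print(best_mae_split(x, y))  # (3.75, 2.4) for this toy data

Swapping node_mae for a mean-based variant in this sketch would indeed run faster, but, as noted above, the split scores it produces would no longer be the quantity the tree is trying to minimize.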

JohnStott commented Mar 4, 2017

Thank you both for your explanations and time with this; I am very appreciative 👍. I apologise for slightly hijacking this thread, but I hope it helps others who come across it with the same questions.

It seems the lack of standard terminology in this domain is where I was mostly tripped up (along with others looking through the various related issue threads here). I was incorrectly assuming that "Mean" Absolute Error meant that the value deducted from X, before taking the absolute value, was the mean... I thus thought there was such a thing as a "Median Absolute Error". However, "Mean Absolute Error" actually covers mean, mode, and median (or any other measure of central tendency): the "mean" in MAE refers to the fact that we divide by N irrespective of the measure of central tendency, i.e., the final .mean() call below is what the "mean" refers to in MAE:

def mse(x, y):
    return ((x - y) ** 2).mean()

def mae(x, y):
    return numpy.abs(x - y).mean()

It is perhaps confusing because the most widely used metric, MSE, predominantly uses the mean as the central tendency; this is where the incorrect assumption can arise (well... in my case at least).

So... it just so happens that the MAE, as applied here (scikit-learn decision trees), uses the "median" as its central tendency!

The other point of confusion is the usual usage of MSE/MAE in learning algorithms, where they are applied to X - y with y the target variable. However, when we are dealing with decision trees, the y is not the "target", as we are looking at a single variable: it is the central point from which we wish to gauge variance/spread, and thus obtain the splits that maximise information gain.

For future reference for myself/others, I found the following links, which cover most of my confusion with regards to MAE, variance, and decision trees:

https://en.wikipedia.org/wiki/Average_absolute_deviation
http://www.saedsayad.com/decision_tree_reg.htm

[please correct me if any of the above is wrong :) ]
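
A minimal usage example of the criterion added in this PR (toy data invented for illustration):

import numpy as np
from sklearn.tree import DecisionTreeRegressor

X = np.array([[0.5], [1.0], [2.0], [3.5], [4.0]])
y = np.array([1.0, 2.0, 5.0, 10.0, 20.0])

# criterion="mae" grows the tree by minimizing mean absolute error;
# leaf predictions are medians of the training targets in each leaf.
reg = DecisionTreeRegressor(criterion="mae", random_state=0)
reg.fit(X, y)
print(reg.predict([[1.5]]))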
