[MRG+2] Make `enet_coordinate_descent_gram` support fused types #7218

Merged
merged 2 commits into scikit-learn:master on Aug 29, 2016

Conversation

@yenchenlin
Contributor

yenchenlin commented Aug 21, 2016

For ElasticNetCV (and related classes), precompute defaults to "auto", which selects the precomputed-Gram code path when n_samples > n_features. To keep the memory benefits of float32 input in that case, enet_coordinate_descent_gram needs to handle float32 dtypes as well.
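
To see why that path matters for memory, here is a minimal illustration (shapes borrowed from the profiling script below; not part of the PR):

import numpy as np

# With n_samples >> n_features, the solver can work on the tiny
# (40, 40) Gram matrix instead of the (2000000, 40) design matrix.
X = np.float32(np.random.rand(2000000, 40))
Gram = X.T.dot(X)
print(Gram.shape, Gram.dtype)  # (40, 40) float32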

This PR makes enet_coordinate_descent_gram support Cython fused types.
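
For readers unfamiliar with the mechanism, a minimal fused-types sketch (illustrative only, not the PR's actual code): Cython's built-in floating fused type covers C float and double, and the compiler emits one specialization per concrete type, so float32 input is processed without upcasting.

from cython cimport floating

def dot(floating[::1] a, floating[::1] b):
    # One specialization is compiled per concrete element type, so
    # float32 inputs stay in float32 end to end.
    cdef Py_ssize_t i
    cdef floating acc = 0
    for i in range(a.shape[0]):
        acc += a[i] * b[i]
    return acc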

**EDIT 8/26**
Here are the memory profiling results when fitting float32 data:

  • This branch (data kept as float32)
  • Master (data upcast to float64)

[the line-by-line memory_profiler output was not captured in this transcript]
@yenchenlin
Contributor

yenchenlin commented Aug 21, 2016

Note that currently, it is based on #6913.

@yenchenlin yenchenlin changed the title from [MRG] Make Cd enet_coordinate_descent_gram support fused types to [MRG] Make `enet_coordinate_descent_gram` support fused types Aug 21, 2016

@agramfort
Member

agramfort commented Aug 21, 2016

@yenchenlin yenchenlin changed the title from [MRG] Make `enet_coordinate_descent_gram` support fused types to [WIP] Make `enet_coordinate_descent_gram` support fused types Aug 21, 2016

@yenchenlin
Contributor

yenchenlin commented Aug 21, 2016

@agramfort Ah sorry 😄
Thanks for the reminder!

I submitted it in advance to make sure CI is all good.

@yenchenlin
Contributor

yenchenlin commented Aug 26, 2016

Here is the profiling test script, which makes sure enet_coordinate_descent_gram is called:

import numpy as np
from sklearn.linear_model.coordinate_descent import ElasticNet
from sys import argv


# @profile is injected by memory_profiler; run e.g.
# `python -m memory_profiler this_script.py 32` (or 64).
@profile
def fit_est():
    clf.fit(X, y)


np.random.seed(5)
X = np.random.rand(2000000, 40)
X = np.float32(X)
y = np.random.rand(2000000)
y = np.float32(y)
T = np.random.rand(5, 40)
T = np.float32(T)

# Passing "64" upcasts the float32 data to float64 for the baseline run.
if argv[1] == "64":
    X = np.float64(X)
    y = np.float64(y)
    T = np.float64(T)

# Precomputing the Gram matrix forces the enet_coordinate_descent_gram path.
Gram = X.T.dot(X)

clf = ElasticNet(alpha=1e-7, l1_ratio=1.0, precompute=Gram)
fit_est()
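
As a quick sanity check (a hypothetical snippet, not from this thread), the fitted coefficients should keep the input dtype once the Gram path supports fused types:

import numpy as np
from sklearn.linear_model import ElasticNet

X = np.random.rand(1000, 40).astype(np.float32)
y = np.random.rand(1000).astype(np.float32)
Gram = X.T.dot(X)  # float32 Gram matrix

clf = ElasticNet(alpha=1e-3, l1_ratio=1.0, precompute=Gram).fit(X, y)
print(clf.coef_.dtype)  # expected: float32 on this branch, float64 on master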

@yenchenlin yenchenlin changed the title from [WIP] Make `enet_coordinate_descent_gram` support fused types to [MRG] Make `enet_coordinate_descent_gram` support fused types Aug 26, 2016

@ogrisel
Member

ogrisel commented Aug 27, 2016

The code looks good to me but I am not sure I understand the output of the memory profiling sessions. Can you please extend it to include the case where the model is fitted without Gram matrix precomputation?

+ np.ndarray[floating, ndim=2, mode='c'] Q,
+ np.ndarray[floating, ndim=1, mode='c'] q,
+ np.ndarray[floating, ndim=1] y,
+ int max_iter, floating tol, object rng,

@ogrisel
Member

ogrisel commented Aug 27, 2016

The scalars alpha, beta and tol should be left as double, since the default casting should work.

@yenchenlin
Contributor

yenchenlin commented Aug 28, 2016

Doing this would introduce some implementation problems, since these variables have to be added to and multiplied with fused-type variables such as w, which may be of type float32.
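
A hypothetical sketch of the friction being described (not the PR's code): with a double scalar and a float32 fused-type buffer, the mixed arithmetic is promoted to double, so every store back into the buffer needs an explicit narrowing cast.

from cython cimport floating

def scale_inplace(double alpha, floating[::1] w):
    cdef Py_ssize_t i
    for i in range(w.shape[0]):
        # alpha * w[i] is evaluated in double precision; writing it
        # back into a float32 w requires an explicit cast.
        w[i] = <floating>(alpha * w[i])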

@yenchenlin
Contributor

yenchenlin commented Aug 28, 2016

@ogrisel sorry I did not make it clear: what I'm comparing here are the memory profiling results of this branch and master.

> the model is fitted without Gram matrix precomputation

This case was already handled in #6913.

@yenchenlin
Contributor

yenchenlin commented Aug 28, 2016

Would @jnothman or @MechCoder please give it a review?
Also, should I include enet_coordinate_descent_multi_task() in this PR, or open another?

@jnothman
Member

jnothman commented Aug 28, 2016

LGTM!

@jnothman jnothman changed the title from [MRG] Make `enet_coordinate_descent_gram` support fused types to [MRG+1] Make `enet_coordinate_descent_gram` support fused types Aug 28, 2016

@yenchenlin
Contributor

yenchenlin commented Aug 29, 2016

ping @MechCoder please 😛

@agramfort agramfort changed the title from [MRG+1] Make `enet_coordinate_descent_gram` support fused types to [MRG+2] Make `enet_coordinate_descent_gram` support fused types Aug 29, 2016

@agramfort
Member

agramfort commented Aug 29, 2016

LGTM

@MechCoder merge if you're also happy

@yenchenlin
Contributor

yenchenlin commented Aug 29, 2016

@agramfort thanks a lot.

@ogrisel
Member

ogrisel commented Aug 29, 2016

LGTM as well. Merging. Thanks @yenchenlin!

@ogrisel ogrisel merged commit 845e702 into scikit-learn:master Aug 29, 2016

3 checks passed

ci/circleci: Your tests passed on CircleCI!
continuous-integration/appveyor/pr: AppVeyor build succeeded
continuous-integration/travis-ci/pr: The Travis CI build passed
@MechCoder
Member

MechCoder commented Aug 30, 2016

Sorry for the delay! Can you also fix ElasticNetCV and LassoCV to not implicitly convert float32 dtypes?
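
The remaining gap can be made visible with a quick dtype check (a hypothetical snippet, not from this thread): at this point the CV estimators still upcast float32 input.

import numpy as np
from sklearn.linear_model import ElasticNetCV

X = np.random.rand(200, 20).astype(np.float32)
y = np.random.rand(200).astype(np.float32)

model = ElasticNetCV().fit(X, y)
# Still float64 here until the follow-up fix lands.
print(model.coef_.dtype)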

@yenchenlin
Contributor

yenchenlin commented Aug 30, 2016

@MechCoder Sure!

TomDLT added a commit to TomDLT/scikit-learn that referenced this pull request Oct 3, 2016

paulha added a commit to paulha/scikit-learn that referenced this pull request Aug 19, 2017
