WIP: Implement ROC-SVM linear pairwise ranking loss with SGD #1386

Open
wants to merge 27 commits into
from

8 participants

@coreylynch

see #1061

I'm taking a stab at porting sofia-ml's ranking SGD SVM to sklearn.

Summary of Additions

sgd_fast.pyx
  • Added a ranking_sgd function. Similar to plain_sgd, but implements the outer loop found in sofia-ml-methods.cc's StochasticRocLoop and SingleSgdSvmRankStep (sofia-ml source).
utils/seq_dataset.pyx
  • Added a PairwiseArrayDataset class.
  • In __init__, I create an index of positives and negatives for fast sampling of disagreeing pairs, similar to StochasticRocLoop in sofia-ml-methods.cc.
  • The .next() method returns a random pair of examples with disagreeing labels. See the 'Indexed Sampling' section of http://www.eecs.tufts.edu/~dsculley/papers/large-scale-rank.pdf; they point out that indexing gives faster sampling when the dataset D fits in memory. (A small illustrative sketch follows this list.)
weight_vector.pyx
  • Added a dot_on_difference method, following the example of sofia-ml's SingleSgdSvmRankStep and the sf-weight-vector.cc implementation.
stochastic_gradient.py
  • Added a 'roc_pairwise_ranking' loss option to the SGDClassifier class.
  • Added a 'ranking' parameter to _make_dataset that defaults to False but is set to True when the loss function is 'roc_pairwise_ranking'; it is used to create a pairwise dataset in seq_dataset.pyx.
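
For illustration, here is a minimal pure-Python sketch of the indexed-sampling and dot-on-difference ideas (names and structure are illustrative only, not the actual Cython API):

import numpy as np


class IndexedPairSampler(object):
    # Toy stand-in for PairwiseArrayDataset: index positives and
    # negatives once so a disagreeing (positive, negative) pair can
    # be drawn in O(1) instead of by rejection sampling.

    def __init__(self, X, y, seed=0):
        self.X = X
        self.rng = np.random.RandomState(seed)
        # Build the index up front -- the 'Indexed Sampling' trick.
        self.pos_idx = np.flatnonzero(y == 1)
        self.neg_idx = np.flatnonzero(y != 1)

    def next(self):
        # One random example from each class: labels always disagree.
        a = self.pos_idx[self.rng.randint(len(self.pos_idx))]
        b = self.neg_idx[self.rng.randint(len(self.neg_idx))]
        return self.X[a], self.X[b]


def dot_on_difference(w, x_a, x_b):
    # The quantity the rank step needs: w . (x_a - x_b), computed
    # without materializing the difference vector.
    return np.dot(w, x_a) - np.dot(w, x_b)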
Problems
  • I tried running the current code on an unbalanced dataset (using this example's data), and I'm not getting the performance I had hoped for.

Here is the code I ran:

import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import confusion_matrix


if __name__ == "__main__":
    # Unbalanced two-class toy data: 10000 negatives, 100 positives.
    n_samples_1 = 10000
    n_samples_2 = 100
    X = np.r_[1.5 * np.random.randn(n_samples_1, 2),
              0.5 * np.random.randn(n_samples_2, 2) + [2, 2]]
    y = np.array([0] * n_samples_1 + [1] * n_samples_2, dtype=np.float64)
    idx = np.arange(y.shape[0])
    np.random.shuffle(idx)
    X = X[idx]
    y = y[idx]
    # Standardize the features.
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    X = (X - mean) / std

    classifiers = [
        (SGDClassifier(n_iter=100, alpha=0.01), "plain sgd"),
        (SGDClassifier(n_iter=100, alpha=0.01,
                       class_weight={1: 10}), "weighted sgd"),
        (SGDClassifier(n_iter=100, alpha=0.01,
                       loss='roc_pairwise_ranking'), "pairwise sgd"),
    ]
    for clf, name in classifiers:
        clf.fit(X, y)
        print clf
        print "SCORE: " + str(clf.score(X, y))
        print "CONFUSION MATRIX: "
        print confusion_matrix(y, clf.predict(X))
        print 80 * '='

... which outputs

SGDClassifier(alpha=0.01, class_weight=None, epsilon=0.1, eta0=0.0,
       fit_intercept=True, l1_ratio=0.15, learning_rate=optimal,
       loss=hinge, n_iter=100, n_jobs=1, penalty=l2, power_t=0.5, rho=None,
       seed=0, shuffle=False, verbose=0, warm_start=False)
SCORE: 0.990099009901
CONFUSION MATRIX: 
[[10000     0]
 [  100     0]]
================================================================================
SGDClassifier(alpha=0.01, class_weight={1: 10}, epsilon=0.1, eta0=0.0,
       fit_intercept=True, l1_ratio=0.15, learning_rate=optimal,
       loss=hinge, n_iter=100, n_jobs=1, penalty=l2, power_t=0.5, rho=None,
       seed=0, shuffle=False, verbose=0, warm_start=False)
SCORE: 0.968415841584
CONFUSION MATRIX: 
[[9731  269]
 [  50   50]]
================================================================================
SGDClassifier(alpha=0.01, class_weight=None, epsilon=0.1, eta0=0.0,
       fit_intercept=True, l1_ratio=0.15, learning_rate=optimal,
       loss=roc_pairwise_ranking, n_iter=100, n_jobs=1, penalty=l2,
       power_t=0.5, rho=None, seed=0, shuffle=False, verbose=0,
       warm_start=False)
SCORE: 0.514257425743
CONFUSION MATRIX: 
[[5094 4906]
 [   0  100]]
================================================================================

Any suggestions for code cleanup would be greatly appreciated!

@pprett
scikit-learn member

awesome - thanks - looking forward to review

@agramfort
scikit-learn member

as a benchmark you can look at @fabianp's Cython bindings of sofia-ml:

https://github.com/fabianp/minirank/tree/master/minirank

I am very enthusiastic about this PR!

@mblondel
scikit-learn member

Supporting passive-aggressive would be nice.

@pprett
scikit-learn member

Should we expose the functionality as a separate class (e.g. RankingSGD - I'm not that good at finding names...) instead of a loss function loss='roc_ranking'?

I'll look into the feature vector abstraction in the evening - this would allow us to support both dense and sparse datasets, as well as merge plain_sgd and ranking_sgd.

@pprett
scikit-learn member

According to D. Sculley, "Our efforts to include pairs with ties (where ya = yb) yielded no additional benefit" -> I think we can skip sampling ties (I did see in the code that y could be 0).

@amueller
scikit-learn member

+1 for an additional class name for discoverability (that is not a word, is it?).

@pprett
scikit-learn member

@coreylynch I did some modifications to your PR:

Basically, I introduced a new FeatureVector abstraction that allows iteration over coefficients, i.e. (index, value) pairs. PairwiseFeatureVector, a subclass of FeatureVector, gets two FeatureVectors as inputs and iterates over both of them (negating the second one).
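
For intuition, a rough pure-Python sketch of that abstraction (illustrative only - the real code is Cython and lives in the branch linked below):

class FeatureVector(object):
    # Iterate over the (index, value) pairs of a single example.
    def __init__(self, indices, values):
        self.indices, self.values = indices, values

    def iter_nonzero(self):
        for i, v in zip(self.indices, self.values):
            yield i, v


class PairwiseFeatureVector(FeatureVector):
    # Wrap two FeatureVectors and iterate over (x_a - x_b): emit
    # a's coefficients as-is, then b's with negated values.
    def __init__(self, fv_a, fv_b):
        self.fv_a, self.fv_b = fv_a, fv_b

    def iter_nonzero(self):
        for i, v in self.fv_a.iter_nonzero():
            yield i, v
        for i, v in self.fv_b.iter_nonzero():
            yield i, -v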

This allows me to get rid of ranking_sgd and WeightVector.dot_on_difference, and means that ranking works with both sparse and dense data as well as with all our current loss functions. But it comes at a price: roughly a factor of two runtime performance degradation:

          master    FV
RCV1*     1.21s     2.74s
CovType   0.39s     0.70s

*RCV1-ccat; n_iter=5

I've tried hard to make it faster but it's tough... I'd really appreciate it if somebody else could have a look at the code - you can find it here:
https://github.com/pprett/scikit-learn/tree/sgd-ranksvm-pprett

the relevant files are seq_dataset.pyx, weight_vector.pyx, and sgd_fast.pyx.

I don't think I can live with a factor of 2, so I'd rather prefer @coreylynch's solution at the moment (sparse support would be great though).

I also tested the results against Fabian's RankSVM - the results are basically the same, but this implementation is way faster - great work @coreylynch!

@coreylynch

@pprett The FeatureVector abstraction looks nice. If you wouldn't mind, I'd be interested in taking a look at your benchmarking code for both the RCV1 dataset and Fabian's RankSVM.

@pprett
scikit-learn member

@coreylynch you can find the benchmark against the RankSVM implementation of @fabianp and @agramfort here: https://gist.github.com/4150478

The RCV1 benchmark requires the RCV1 input files from Léon Bottou's sgd project: http://leon.bottou.org/projects/sgd (you need to download the lyrl2004_tokens files and run the convert.py script). I used the following benchmark code: https://gist.github.com/4150519 .

@coreylynch

@pprett @agramfort I added the Kendall correlation coefficient to the benchmark script and included minirank's sofia-ml binding classifier. Code is here: https://gist.github.com/4150976

The results of the script are:

SGDClassifier(alpha=0.01, class_weight=None, epsilon=0.1, eta0=0.0,
       fit_intercept=True, l1_ratio=0.15, learning_rate=optimal,
       loss=hinge, n_iter=100, n_jobs=1, penalty=l2, power_t=0.5, rho=None,
       seed=0, shuffle=False, verbose=0, warm_start=False)
ACC: 0.9900
AUC: 0.5000
CONFUSION MATRIX: 
[[9999    1]
 [ 100    0]]
Kendall Tau: 0.1300
================================================================================
SGDClassifier(alpha=0.01, class_weight={1: 10}, epsilon=0.1, eta0=0.0,
       fit_intercept=True, l1_ratio=0.15, learning_rate=optimal,
       loss=hinge, n_iter=100, n_jobs=1, penalty=l2, power_t=0.5, rho=None,
       seed=0, shuffle=False, verbose=0, warm_start=False)
ACC: 0.9711
AUC: 0.7181
CONFUSION MATRIX: 
[[9762  238]
 [  54   46]]
Kendall Tau: 0.1299
================================================================================
SGDClassifier(alpha=0.01, class_weight=None, epsilon=0.1, eta0=0.0,
       fit_intercept=True, l1_ratio=0.15, learning_rate=optimal,
       loss=roc_pairwise_ranking, n_iter=1000, n_jobs=1, penalty=l2,
       power_t=0.5, rho=None, seed=0, shuffle=False, verbose=0,
       warm_start=False)
ACC: 0.5144
AUC: 0.7548
CONFUSION MATRIX: 
[[5095 4905]
 [   0  100]]
Kendall Tau: 0.1299
================================================================================

RankSVM(alpha=0.01, class_weight=None, epsilon=0.1, eta0=0.0,
    fit_intercept=True, l1_ratio=0.15, learning_rate=optimal, loss=hinge,
    n_iter=100, n_jobs=1, penalty=l2, power_t=0.5, rho=None, seed=0,
    shuffle=False, verbose=0, warm_start=False)
ACC: 0.5148
AUC: 0.7550
CONFUSION MATRIX: 
[[5099 4901]
 [   0  100]]
Kendall Tau: 0.1299
================================================================================
RankSVM(alpha=0.01, max_iter=100, model=rank)
ACC: 0.5163
AUC: 0.7558
CONFUSION MATRIX: 
[[5115 4885]
 [   0  100]]
Kendall Tau: 0.1299
================================================================================

I also created a separate class called RankingSGD that inherits from SGDClassifier, with roc_pairwise_ranking as the default loss function. This class adds a rank method that can be used to rank a test set with a trained ranking model. Feedback still very welcome! Thanks
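
For context, hypothetical usage of the new class, reusing X and y from the script above (RankingSGD only exists on this branch, and is renamed SGDRanking later in this thread):

from sklearn.linear_model import RankingSGD

ranker = RankingSGD(alpha=0.01, n_iter=100)
ranker.fit(X, y)
positions = ranker.rank(X)  # position of each sample in the predicted ordering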

@agramfort agramfort commented on an outdated diff Nov 27, 2012
sklearn/linear_model/__init__.py
@@ -18,7 +18,7 @@
lasso_path, enet_path, MultiTaskLasso, \
MultiTaskElasticNet
from .sgd_fast import Hinge, Log, ModifiedHuber, SquaredLoss, Huber
-from .stochastic_gradient import SGDClassifier, SGDRegressor
+from .stochastic_gradient import SGDClassifier, RankingSGD, SGDRegressor
@agramfort
scikit-learn member
agramfort added a line comment Nov 27, 2012

To be consistent it should be called SGDRanking

I like the idea of having a separate estimator, as the default score can then be Kendall's tau. Otherwise the score function depends on the loss. Just thinking out loud...

@agramfort agramfort commented on an outdated diff Nov 27, 2012
sklearn/linear_model/stochastic_gradient.py
+ fit_intercept=fit_intercept,
+ n_iter=n_iter, shuffle=shuffle,
+ verbose=verbose, epsilon=epsilon,
+ seed=seed, rho=rho,
+ learning_rate=learning_rate,
+ eta0=eta0, power_t=power_t,
+ warm_start=warm_start)
+ self.class_weight = class_weight
+ self.classes_ = None
+ self.n_jobs = int(n_jobs)
+
+ def rank(self,X):
+ order = np.argsort(np.dot(X,self.coef_[0]))
+ order_inv = np.zeros_like(order)
+ order_inv[order] = np.arange(len(order))
+ return order_inv
@agramfort
scikit-learn member
agramfort added a line comment Nov 27, 2012

can you add a score method computing Kendall's tau?
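
Something along these lines, perhaps (a sketch only, assuming scipy's kendalltau; not the code that was eventually committed):

import numpy as np
from scipy.stats import kendalltau

def score(self, X, y):
    # Kendall's tau between the model's ranking scores and y.
    tau, _ = kendalltau(np.dot(X, self.coef_[0]), y)
    return tau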

@agramfort agramfort commented on an outdated diff Nov 27, 2012
sklearn/linear_model/stochastic_gradient.py
@@ -676,6 +696,34 @@ def fit_binary(est, i, X, y, n_iter, pos_weight, neg_weight,
est.power_t, est.t_, intercept_decay)
+class RankingSGD(SGDClassifier):
@agramfort
scikit-learn member
agramfort added a line comment Nov 27, 2012

you'll need to add docstrings.

@coreylynch coreylynch 1. Added docstrings
2. Forced loss in ranking to be roc_pairwise_ranking
3. Changed RankingSGD to SGDRanking
4. Implemented a 'score' function with Kendall Tau
6478381
@coreylynch

@agramfort I added docstrings and the score method. Not sure what to put under the See Also section of the docs. Happy to hear any rewording suggestions.

@agramfort agramfort commented on an outdated diff Nov 27, 2012
sklearn/linear_model/stochastic_gradient.py
@@ -676,6 +695,190 @@ def fit_binary(est, i, X, y, n_iter, pos_weight, neg_weight,
est.power_t, est.t_, intercept_decay)
+class SGDRanking(SGDClassifier):
+ """Ranking model fitted by minimizing a regularized empirical loss with SGD
@agramfort
scikit-learn member
agramfort added a line comment Nov 27, 2012

pep 257

docstring should be a first short sentence on one line
then the main text block

@agramfort agramfort commented on the diff Nov 27, 2012
sklearn/linear_model/stochastic_gradient.py
+ Parameters
+ ----------
+ X : array-like, shape = [n_samples, n_features]
+ Test set.
+
+ Returns
+ -------
+ order_inv : array-like, shape = [n_samples]
+ """
+ order = np.argsort(np.dot(X, self.coef_[0]))
+ order_inv = np.zeros_like(order)
+ order_inv[order] = np.arange(len(order))
+ return order_inv
+
+ def score(self, X, y):
+ """
@agramfort
scikit-learn member
agramfort added a line comment Nov 27, 2012

pep 257 here too

@agramfort agramfort commented on an outdated diff Nov 27, 2012
sklearn/linear_model/stochastic_gradient.py
+ Test set.
+
+ Returns
+ -------
+ order_inv : array-like, shape = [n_samples]
+ """
+ order = np.argsort(np.dot(X, self.coef_[0]))
+ order_inv = np.zeros_like(order)
+ order_inv[order] = np.arange(len(order))
+ return order_inv
+
+ def score(self, X, y):
+ """
+ Returns the Kendall's tau correlation coefficient, which is used in the
+ ranking literature (e.g T. Joachims, Optimizing Search Engines using
+ Clickthrough Data, KDD 2002) to compare the ordering of a model to a
@agramfort
scikit-learn member
agramfort added a line comment Nov 27, 2012

the reference should appear outside of the main doc block of text.

"""Returns the Kendall's tau correlation coefficient

The Kendall's tau correlation is used to evaluate the quality of a
ranking prediction. It compares the ordering predicted by the model to a
given ordering.

Reference
T. Joachims, Optimizing Search Engines using Clickthrough Data, KDD 2002
"""
@agramfort agramfort commented on an outdated diff Nov 27, 2012
sklearn/linear_model/stochastic_gradient.py
+ super(SGDRanking, self).__init__(loss=loss, penalty=penalty,
+ alpha=alpha, l1_ratio=l1_ratio,
+ fit_intercept=fit_intercept,
+ n_iter=n_iter, shuffle=shuffle,
+ verbose=verbose, epsilon=epsilon,
+ seed=seed, rho=rho,
+ learning_rate=learning_rate,
+ eta0=eta0, power_t=power_t,
+ warm_start=warm_start)
+ self.class_weight = class_weight
+ self.classes_ = None
+ self.n_jobs = int(n_jobs)
+ if self.loss != "roc_pairwise_ranking":
+ raise ValueError("The loss %s is not supported. " % self.loss)
+
+ def rank(self, X):
@agramfort
scikit-learn member
agramfort added a line comment Nov 27, 2012

I feel this method is really the predict method, and the decision_function should just be np.dot(X, self.coef_)

@agramfort
scikit-learn member

A more general design issue is open. Although pairwise ranking is commonly done with a hinge loss, it can work with any binary classification loss (log, hinge, huber, etc.). That's why I feel that loss='roc_pairwise_ranking' in SGDClassifier is not the best approach. SGDRanking should accept a loss of hinge, log, etc., and if you really want a pairwise loss exposed in SGDClassifier I would call it hinge_pairwise.

Second thing: ranking is commonly done with a non-total order. Not all pairs can be formed; typically only the pairs within a query are compared. With @fabianp we do it by passing an extra vector containing the query_id: if the query_id is the same for two samples, then they can form a pair. I think this feature is really important when you do ranking.
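
For illustration, a pure-Python sketch of query-constrained pair sampling (names are illustrative, not the sofia-ml or PR code):

import numpy as np
from collections import defaultdict

def build_query_index(query_id):
    # Group sample indices by query_id (done once, up front).
    by_query = defaultdict(list)
    for i, q in enumerate(query_id):
        by_query[q].append(i)
    return by_query

def sample_within_query_pair(y, query_id, by_query, rng):
    # Draw a pair (a, b) with query_id[a] == query_id[b] and
    # y[a] != y[b]: only samples from the same query are compared.
    while True:
        a = rng.randint(len(y))
        cands = [j for j in by_query[query_id[a]] if y[j] != y[a]]
        if cands:  # retry if this query has only one label
            return a, cands[rng.randint(len(cands))]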

Last but not least we'll need some narrative doc and an example that illustrates the pros of the SGDRanking estimator.

what do you guys think?

@mblondel
scikit-learn member

A more general design issue is open. Although pairwise ranking is commonly done with a hinge loss, it can work with any binary classification loss (log, hinge, huber, etc.). That's why I feel that loss='roc_pairwise_ranking' in SGDClassifier is not the best approach. SGDRanking should accept a loss of hinge, log, etc., and if you really want a pairwise loss exposed in SGDClassifier I would call it hinge_pairwise.

+1 for supporting arbitrary losses. The loss parameter should not be used to enable / disable ranking.

Second thing: ranking is commonly done with a non-total order. Not all pairs can be formed; typically only the pairs within a query are compared. With @fabianp we do it by passing an extra vector containing the query_id: if the query_id is the same for two samples, then they can form a pair. I think this feature is really important when you do ranking.

+1

This will require supporting different pair sampling schemes in the pairwise dataset. We could add a sampling option to SGDRanking (sampling="balanced" for the ROC/AUC SVM and sampling="group" for RankSVM).

@fabianp
scikit-learn member

I feel the same about the query_id mentioned by @agramfort, which is also present in sofia-ml.

coreylynch added some commits Nov 29, 2012
@coreylynch coreylynch 1. added sampling option to SGDRanking (init and next behavior in seq_dataset.pyx now depend on the sampling parameter)
2. added query_id parameter
3. cleaned up .next() code in seq_dataset.pyx thanks to @pprett
adf3047
@coreylynch coreylynch refactored PairwiseArrayDataset into different classes f90a159
@coreylynch coreylynch 1. Added sampling option to SGDRanking
2. Added query_id parameter support
3. Implemented StochasticRankLoop, StochasticRocLoop
d6da34e
@coreylynch

@mblondel SGDRanking now supports arbitrary losses

@agramfort & @fabianp query_id is now supported in the fit() method, backed by Cython implementations of sofia-ml's StochasticRocLoop and StochasticRankLoop.

Thanks for taking a look!

@agramfort

query_id is missing

@agramfort

class_weight is missing

@agramfort

if it's deprecated don't add it to a new estimator :)

@agramfort

line too long

@agramfort
scikit-learn member

looks promising, and I like the design/usability better. There might be some code duplication that we could avoid, but let's keep that for later.

Can you now add tests and an example that illustrates the pros of ROC / ranking with different sampling?

@pprett
scikit-learn member

@agramfort @coreylynch we should definitely support sparse inputs too; you should be able to copy&paste the necessary code from https://github.com/pprett/scikit-learn/tree/sgd-ranksvm-pprett

@coreylynch

@pprett sparse inputs are now supported.

@agramfort
scikit-learn member

line too long

@agramfort
scikit-learn member

line too long. Sorry to bug you...

@coreylynch

@agramfort No problem! I appreciate the review.

@ogrisel
scikit-learn member

Could you merge the current master into your branch?

@ogrisel ogrisel commented on an outdated diff Dec 4, 2012
sklearn/linear_model/sgd_fast.pyx
+ q_data_ptr = <DOUBLE *> q.data
+ cdef double u = 0.0
+
+ if penalty_type == L2:
+ rho = 1.0
+ elif penalty_type == L1:
+ rho = 0.0
+
+ eta = eta0
+
+ t_start = time()
+ for epoch in range(n_iter):
+ if verbose > 0:
+ print("-- Epoch %d" % (epoch + 1))
+ #if shuffle:
+ # dataset.shuffle(seed)
@ogrisel
scikit-learn member
ogrisel added a line comment Dec 4, 2012

Please delete those commented out lines if you don't need them.

@coreylynch

@ogrisel Sure. Do you prefer git merge master or git rebase master?

coreylynch added some commits Dec 4, 2012
@coreylynch coreylynch Merge branch 'master' into roc-svm-pairwise-sgd
Conflicts:
	sklearn/linear_model/sgd_fast.c
	sklearn/linear_model/sgd_fast.pyx
	sklearn/linear_model/stochastic_gradient.py
4d4aaa2
@coreylynch coreylynch Took out comments fea82a0
@coreylynch coreylynch fixed seq_dataset.c 1d1580d
@coreylynch

After my merge, the Travis build fails in test_sgd.py on the test_sgd_multiclass_njobs test. When I tried reproducing it on my machine, I got

PicklingError: Can't pickle <type 'sgd_fast.Hinge'>: it's not the same object as sgd_fast.Hinge

Strangely,

clf = SGDRanking(alpha=0.01, n_iter=20, n_jobs=2).fit(X2, Y2)

throws no error. I'll keep looking at this but I was curious if anyone had any insights.

@pprett
scikit-learn member

@coreylynch if you compile a Cython module you need to re-compile all modules that import it (I guess you compiled seq_dataset.pyx, so you just need to re-compile sgd_fast.pyx and it should be fine)

@pprett
scikit-learn member

btw: I didn't manage to review your latest commits yet - I promise I'll do so in the next couple of days so that we can merge this PR soon

@pprett pprett commented on an outdated diff Dec 10, 2012
sklearn/utils/seq_dataset.pyx
@@ -174,3 +185,376 @@ cdef class CSRDataset(SequentialDataset):
cdef void shuffle(self, seed):
np.random.RandomState(seed).shuffle(self.index)
+
+
+cdef class PairwiseDataset:
+ """Base class for datasets with sequential access to pairs."""
+ """
@pprett
scikit-learn member
pprett added a line comment Dec 10, 2012

just a minor thing: you can remove the second pair of opening and closing triple quotes above.

@pprett pprett commented on an outdated diff Dec 10, 2012
sklearn/linear_model/stochastic_gradient.py
+ intercept_decay = SPARSE_INTERCEPT_DECAY
+ else:
+ dataset = PairwiseArrayDatasetRoc(X, y_i)
+ intercept_decay = 1.0
+ else:
+ if query_id is None:
+ query_id = np.ones(X.shape[0])
+ if sp.issparse(X):
+ dataset = PairwiseCSRDatasetRank(X.data, X.indptr, X.indices,
+ y_i, query_id)
+ intercept_decay = SPARSE_INTERCEPT_DECAY
+ else:
+ dataset = PairwiseArrayDatasetRank(X, y_i, query_id)
+ intercept_decay = 1.0
+ elif sp.issparse(X):
+ dataset = CSRDataset(X.data, X.indptr, X.indices, y_i)
@pprett
scikit-learn member
pprett added a line comment Dec 10, 2012

you forgot to pass sample_weight - tests do not pass

@pprett pprett commented on an outdated diff Dec 10, 2012
sklearn/utils/seq_dataset.pyx
+ cdef DOUBLE group_id
+ cdef int y_range
+
+ # set vector a
+ for i in range(num_chances):
+ a_idx = rand() % self.n_samples
+ a_offset = self.X_indptr_ptr[a_idx]
+ a_data_ptr[0] = self.X_data_ptr + a_offset
+ a_ind_ptr[0] = self.X_indices_ptr + a_offset
+ y_a[0] = self.Y_data_ptr[a_idx]
+ nnz_a[0] = self.X_indptr_ptr[a_idx + 1] - a_offset
+ group_id = self.query_data_ptr[a_idx]
+ y_to_list = self.group_id_y_to_index[group_id]
+ y_range = self.group_id_y_to_count[group_id] - \
+ len(self.group_id_y_to_index[group_id][y_a[0]])
+ if (y_range==0):
@pprett
scikit-learn member
pprett added a line comment Dec 10, 2012

no brackets needed around the condition; could you add a comment that describes the semantics of y_range? You could also drop the continue statement if you negate the if expression and break when it evaluates to true.

@pprett pprett commented on the diff Dec 10, 2012
sklearn/utils/seq_dataset.pyx
+ a_ind_ptr[0] = self.X_indices_ptr + a_offset
+ y_a[0] = self.Y_data_ptr[a_idx]
+ nnz_a[0] = self.X_indptr_ptr[a_idx + 1] - a_offset
+ group_id = self.query_data_ptr[a_idx]
+ y_to_list = self.group_id_y_to_index[group_id]
+ y_range = self.group_id_y_to_count[group_id] - \
+ len(self.group_id_y_to_index[group_id][y_a[0]])
+ if (y_range==0):
+ continue
+ break
+
+ # set vector b
+ cdef unsigned int random_int = rand() % y_range
+ cdef int b_idx
+ cdef int b_offset
+ for b_y, idx_list in y_to_list.items():
@pprett
scikit-learn member
pprett added a line comment Dec 10, 2012

is it possible that the b_* variables are not properly set? e.g. what if random_int < len(idx_list) is never true?

@coreylynch coreylynch 1. Added sample_weight support
2. Removed unnecessary conditional from seq_dataset
3. Took out num_chances
93cc5f6
@coreylynch

@pprett The conditional was taken from sofia-ml's StochasticRankLoop function, but I'm now thinking the random_int < len(idx_list) conditional is unnecessary. Here's my reasoning (a tiny sanity check follows the list):

  1. group_id_y_to_count[group_id] is the count of all positive and negative y values in the given group_id
  2. len(group_id_y_to_index[group_id][y_a[0]]) is the number of examples with y_a's label in the given group_id
  3. y_range is (1) minus (2), i.e. the number of examples without y_a's label in the given group_id
  4. rand_int is a random integer modulo y_range, so 0 <= rand_int < y_range
  5. y_to_list is a dictionary mapping each y value to the list of indices of examples with that label
  6. We iterate over y_to_list, skipping the key/value pair matching y_a's value, leaving us with the idx_list for the y value opposite y_a. The length of this idx_list equals y_range, as it is the number of examples without y_a's label in the given group_id
  7. Since y_range and len(idx_list) in 6. are the same number, rand_int is always less than len(idx_list), making the conditional irrelevant
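
A tiny numeric check of step 7 (illustrative only):

# When len(idx_list) equals y_range, rand() % y_range is always a
# valid index into idx_list, so the random_int < len(idx_list) guard
# can never fail.
import random

y_range = 7
idx_list = range(y_range)  # same length by construction
assert all(random.randrange(y_range) < len(idx_list) for _ in range(10000))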
@fabianp
scikit-learn member

I'm currently trying the ranking stuff

@fabianp fabianp commented on the diff Dec 21, 2012
sklearn/linear_model/stochastic_gradient.py
+
+ loss_functions = {
+ "hinge": (Hinge, 1.0),
+ "squared_hinge": (SquaredHinge, 1.0),
+ "perceptron": (Hinge, 0.0),
+ "log": (Log, ),
+ "modified_huber": (ModifiedHuber, ),
+ "squared_loss": (SquaredLoss, ),
+ "huber": (Huber, DEFAULT_EPSILON),
+ "epsilon_insensitive": (EpsilonInsensitive, DEFAULT_EPSILON),
+ }
+
+ def __init__(self, loss="hinge", penalty='l2', alpha=0.0001,
+ l1_ratio=0.15, fit_intercept=True, n_iter=5, shuffle=False,
+ verbose=0, epsilon=DEFAULT_EPSILON, n_jobs=1, seed=0,
+ learning_rate="optimal", eta0=0.0, power_t=0.5,
@fabianp
scikit-learn member
fabianp added a line comment Dec 21, 2012

What is epsilon? It doesn't appear in the docstring.

@fabianp fabianp commented on the diff Dec 21, 2012
sklearn/linear_model/stochastic_gradient.py
+ loss_functions = {
+ "hinge": (Hinge, 1.0),
+ "squared_hinge": (SquaredHinge, 1.0),
+ "perceptron": (Hinge, 0.0),
+ "log": (Log, ),
+ "modified_huber": (ModifiedHuber, ),
+ "squared_loss": (SquaredLoss, ),
+ "huber": (Huber, DEFAULT_EPSILON),
+ "epsilon_insensitive": (EpsilonInsensitive, DEFAULT_EPSILON),
+ }
+
+ def __init__(self, loss="hinge", penalty='l2', alpha=0.0001,
+ l1_ratio=0.15, fit_intercept=True, n_iter=5, shuffle=False,
+ verbose=0, epsilon=DEFAULT_EPSILON, n_jobs=1, seed=0,
+ learning_rate="optimal", eta0=0.0, power_t=0.5,
+ class_weight=None, warm_start=False, rho=None,
@fabianp
scikit-learn member
fabianp added a line comment Dec 21, 2012

and rho ?

@fabianp fabianp commented on the diff Dec 21, 2012
sklearn/linear_model/stochastic_gradient.py
+
+ def predict(self, X):
+ """Returns an index that ranks a test set according to the ranking
+ model.
+
+ Parameters
+ ----------
+ X : array-like, shape = [n_samples, n_features]
+ Test set.
+
+ Returns
+ -------
+ order_inv : array-like, shape = [n_samples]
+ """
+ order = np.argsort(np.dot(X, self.coef_[0]))
+ order_inv = np.zeros_like(order)
@fabianp
scikit-learn member
fabianp added a line comment Dec 21, 2012

Won't work on sparse arrays. Substitute np.dot(X, self.coef_[0]) with X.dot(self.coef_[0])

@fabianp
scikit-learn member
fabianp added a line comment Dec 21, 2012

or safe_sparse_dot
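
For example, with the existing helper in sklearn.utils.extmath that handles dense and sparse X alike:

from sklearn.utils.extmath import safe_sparse_dot

scores = safe_sparse_dot(X, self.coef_[0])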

@fabianp
scikit-learn member

Hi,

I tried this implementation on the LETOR (ranking) dataset.

Here is the notebook that summarizes my results: http://nbviewer.ipython.org/url/fseoane.net/tmp/2012/SGDRanking.ipynb . I don't know much about SGD, so chances are I'm not setting the parameters correctly, but using it as a black box doesn't yield good results: it's several times slower than sofia-ml and the scores are worse.

@fabianp
scikit-learn member

And another comparison with simulated data and no query_id: http://nbviewer.ipython.org/url/fseoane.net/tmp/2012/SGDRanking_simulated.ipynb

In this one the scores look right; it just reveals the timing difference.

@larsmans larsmans commented on the diff Feb 13, 2013
sklearn/linear_model/stochastic_gradient.py
+ assert y_i.shape[0] == y.shape[0] == sample_weight.shape[0]\
+ == query_id_i.shape[0]
+ sampling_type = est._get_sampling_type(est.sampling)
+ dataset, intercept_decay = _make_dataset(X, y_i, sample_weight,
+ query_id_i, sampling_type)
+ penalty_type = est._get_penalty_type(est.penalty)
+ learning_rate_type = est._get_learning_rate_type(est.learning_rate)
+ return ranking_sgd(coef, intercept, est.loss_function,
+ penalty_type, alpha, est.l1_ratio,
+ dataset, n_iter, int(est.fit_intercept),
+ int(est.verbose), int(est.shuffle), est.seed,
+ learning_rate_type, est.eta0,
+ est.power_t, est.t_, intercept_decay, sampling_type)
+
+
+class SGDRanking(SGDClassifier):
@larsmans
scikit-learn member
larsmans added a line comment Feb 13, 2013

Why does this derive from SGDClassifier? It seems like it should derive from BaseSGD instead, or maybe some of the common code should be factored out to a new intermediate class.
