Skip to content

Commit

Permalink
API + fix unit test on regression
Browse files Browse the repository at this point in the history
  • Loading branch information
sdpython committed Feb 10, 2018
1 parent d7086e6 commit cad53f3
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 6 deletions.
7 changes: 6 additions & 1 deletion _doc/sphinxdoc/source/api/datasets.rst
Expand Up @@ -6,7 +6,12 @@ Jeux de données
.. contents::
:local:

Recommandations
===============

.. autosignature:: papierstat.datasets.load_movielens_dataset

Régression
++++++++++
==========

.. autosignature:: papierstat.datasets.load_wines_dataset
1 change: 1 addition & 0 deletions _doc/sphinxdoc/source/api/index.rst
Expand Up @@ -7,3 +7,4 @@ API
:maxdepth: 1

datasets
mltricks
12 changes: 12 additions & 0 deletions _doc/sphinxdoc/source/api/mltricks.rst
@@ -0,0 +1,12 @@

===========================
Astuces de machine learning
===========================

.. contents::
:local:

Autour de scikit-learn
======================

.. autosignature:: papierstat.mltricks.sklearn_base_transform_learner.SkBaseTransformLearner
8 changes: 4 additions & 4 deletions _unittests/ut_mltricks/test_sklearn_convert.py
Expand Up @@ -42,7 +42,7 @@
from sklearn.datasets import load_iris
from sklearn.linear_model import LogisticRegression, LinearRegression
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.metrics import accuracy_score
from sklearn.metrics import accuracy_score, r2_score
from sklearn.pipeline import make_pipeline
from src.papierstat.mltricks import SkBaseTransformLearner

Expand Down Expand Up @@ -105,13 +105,13 @@ def test_pipeline_with_two_regressors(self):
pipe = make_pipeline(conv, DecisionTreeRegressor())
pipe.fit(X_train, y_train)
pred = pipe.predict(X_test)
score = accuracy_score(y_test, pred)
self.assertGreater(score, 0.92)
score = r2_score(y_test, pred)
self.assertLesser(score, 1.)
score2 = pipe.score(X_test, y_test)
self.assertEqual(score, score2)
rp = repr(conv)
self.assertStartsWith(
'SkBaseTransformLearner(model=LogisticRegression(C=1.0,', rp)
'SkBaseTransformLearner(model=LinearRegression(copy_X=True,', rp)


if __name__ == "__main__":
Expand Down
6 changes: 5 additions & 1 deletion src/papierstat/mltricks/sklearn_base_transform_learner.py
Expand Up @@ -4,6 +4,7 @@
@brief Implémente un *transform* qui suit la même API que tout :epkg:`scikit-learn` transform.
"""
import textwrap
import numpy
from .sklearn_base_transform import SkBaseTransform


Expand Down Expand Up @@ -110,7 +111,10 @@ def transform(self, X):
@param X features
@return prédictions
"""
return self.method(X)
res = self.method(X)
if len(res.shape) == 1:
res = res[:, numpy.newaxis]
return res

##############
# cloning API
Expand Down

0 comments on commit cad53f3

Please sign in to comment.