WIP: core: support autotune #19

Merged
merged 5 commits on Jan 20, 2022
17 changes: 17 additions & 0 deletions README.rst
@@ -136,6 +136,23 @@ These wrappers assume the ``X`` parameter given to ``fit``, ``predict``, and ``p
>>> sk_clf.predict(['woof'])
>>> sk_clf.predict(df['txt'])

Hyper-parameter auto-tuning
----------------------------

It is possible to pass a validation set to ``fit`` in order to automatically optimize the model's hyper-parameters.
Additional `auto-tune settings <https://fasttext.cc/docs/en/autotune.html>`_ can be passed to the constructor. For example:

.. code-block:: python

>>> import pandas
>>> from skift import SeriesFtClassifier
>>> df = pandas.DataFrame([['woof', 0], ['meow', 1]], columns=['txt', 'lbl'])
>>> df_val = pandas.DataFrame([['woof woof', 0], ['meow meow', 1]], columns=['txt', 'lbl'])
>>> sk_clf = SeriesFtClassifier(epoch=8, autotuneDuration=5)
>>> sk_clf.fit(df['txt'], df['lbl'], df_val['txt'], df_val['lbl'])
>>> sk_clf.predict(['woof'])
>>> sk_clf.predict(df['txt'])
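
Under the hood, the wrapper dumps both the training set and the validation set to temporary
fastText-format files and forwards them to ``fasttext.train_supervised``, with the validation
file passed as ``autotuneValidationFile``. Roughly, this amounts to the following direct call
(a sketch only; the file names are illustrative, the real ones are temporary files):

.. code-block:: python

>>> import fasttext
>>> model = fasttext.train_supervised(
...     input='train.ft.txt',                   # dumped training set
...     autotuneValidationFile='valid.ft.txt',  # dumped validation set
...     autotuneDuration=5,                     # any other constructor kwargs
... )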



Contributing
============
52 changes: 44 additions & 8 deletions skift/core.py
@@ -95,7 +95,7 @@ def _validate_y(y):
def _input_col(self, X):
pass # pragma: no cover

def fit(self, X, y):
def fit(self, X, y, X_validation=None, y_validation=None):
"""Fits the classifier

Parameters
@@ -104,6 +104,10 @@ def fit(self, X, y):
The training input samples.
y : array-like, shape = [n_samples]
The target values. An array of int.
X_validation : array-like, shape = [n_samples, n_features], optional
The validation input samples, used for hyper-parameter auto-tuning.
y_validation : array-like, shape = [n_samples], optional
The validation target values. An array of int.

Returns
-------
@@ -114,10 +118,16 @@
self._validate_x(X)
y = self._validate_y(y)
input_col = self._input_col(X)
if X_validation is not None:
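# validate and convert the validation set the same way as the training set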
self._validate_x(X_validation)
y_validation = self._validate_y(y_validation)
input_col_validation = self._input_col(X_validation)
else:
input_col_validation = None

return self._fit_input_col(input_col, y)
return self._fit_input_col(input_col, y, input_col_validation, y_validation)

def _fit_input_col(self, input_col, y):
def _fit_input_col(self, input_col, y, input_col_validation=None, y_validation=None):
# Store the classes seen during fit
self.classes_ = unique_labels(y)
self.num_classes_ = len(self.classes_)
@@ -126,9 +136,23 @@ def _fit_input_col(self, input_col, y):
# Dump training set to a fasttext-compatible file
temp_trainset_fpath = temp_dataset_fpath()
dump_xy_to_fasttext_format(input_col, y, temp_trainset_fpath)
# train
self.model = train_supervised(
input=temp_trainset_fpath, **self.kwargs)
if input_col_validation is not None:
n_classes_validation = len(unique_labels(y_validation))
assert n_classes_validation == self.num_classes_,\
"Number of validation classes doesn't match number of training classes"
temp_trainset_fpath_validation = temp_dataset_fpath()
dump_xy_to_fasttext_format(input_col_validation, y_validation, temp_trainset_fpath_validation)
# train
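# (constructor kwargs take precedence over the generated autotuneValidationFile entry)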
self.model = train_supervised(
input=temp_trainset_fpath, **{'autotuneValidationFile': temp_trainset_fpath_validation, **self.kwargs})
try:
os.remove(temp_trainset_fpath_validation)
except FileNotFoundError: # pragma: no cover
pass
else:
self.model = train_supervised(
input=temp_trainset_fpath, **self.kwargs)

# Return the classifier
try:
os.remove(temp_trainset_fpath)
@@ -372,7 +396,7 @@ def __init__(self, **kwargs):
def _input_col(self, X):
pass

def fit(self, X, y):
def fit(self, X, y, X_validation=None, y_validation=None):
"""Fits the classifier

Parameters
Expand All @@ -381,6 +405,10 @@ def fit(self, X, y):
The training input samples.
y : array-like, shape = [n_samples]
The target values. An array of int.
X_validation : pd.Series or array-like, optional
The validation input samples, used for hyper-parameter auto-tuning.
y_validation : array-like, shape = [n_samples], optional
The validation target values. An array of int.

Returns
-------
@@ -393,7 +421,15 @@
except AttributeError:
input_col = X
y = self._validate_y(y)
return self._fit_input_col(input_col, y)
if X_validation is not None:
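# accept either a pandas Series (via .values) or any other array-like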
try:
input_col_validation = X_validation.values
except AttributeError:
input_col_validation = X_validation
y_validation = self._validate_y(y_validation)
else:
input_col_validation = None
return self._fit_input_col(input_col, y, input_col_validation, y_validation)

def _predict(self, X, k=1):
# Ensure that fit had been called
53 changes: 52 additions & 1 deletion tests/test_common.py
@@ -20,6 +20,13 @@ def _ftdf():
)


def _ftdf2():
return pd.DataFrame(
data=[['woof', 0], ['meow', 1]],
columns=['txt', 'lbl']
)


def test_bad_shape():
ft_clf = FirstColFtClassifier()
with pytest.raises(ValueError):
@@ -28,6 +35,28 @@ def test_bad_shape():
ft_clf.fit([[7]], [[0]])


def test_series_predict_tune():
ftdf = _ftdf()
ftdf2 = _ftdf2()
ft_clf = SeriesFtClassifier(autotuneDuration=5)
ft_clf.fit(ftdf['txt'], ftdf['lbl'], ftdf2['txt'], ftdf2['lbl'])

preds = ft_clf.predict(ftdf['txt'])
assert preds[0] == 0
assert preds[1] == 1


def test_series_np_predict_tune():
ftdf = _ftdf()
ftdf2 = _ftdf2()
ft_clf = SeriesFtClassifier(autotuneDuration=5)
ft_clf.fit(ftdf['txt'].values, ftdf['lbl'].values, ftdf2['txt'].values, ftdf2['lbl'].values)

preds = ft_clf.predict(ftdf['txt'])
assert preds[0] == 0
assert preds[1] == 1


def test_series_predict():
ftdf = _ftdf()
ft_clf = SeriesFtClassifier()
@@ -38,6 +67,29 @@ def test_series_predict():
assert preds[1] == 1


def test_series_np_predict():
ftdf = _ftdf()
ft_clf = SeriesFtClassifier()
ft_clf.fit(ftdf['txt'].values, ftdf['lbl'].values)

preds = ft_clf.predict(ftdf['txt'].values)
assert preds[0] == 0
assert preds[1] == 1


def test_predict_tune():
ftdf = _ftdf()
ftdf2 = _ftdf2()
ft_clf = FirstColFtClassifier(autotuneDuration=5)
ft_clf.fit(ftdf[['txt']], ftdf['lbl'], X_validation=ftdf2[['txt']], y_validation=ftdf2['lbl'])

assert ft_clf.predict([['woof woof']])[0] == 0
assert ft_clf.predict([['meow meow']])[0] == 1
assert ft_clf.predict([['meow']])[0] == 1
assert ft_clf.predict([['woof lol']])[0] == 0
assert ft_clf.predict([['meow lolz']])[0] == 1


def test_predict():
ftdf = _ftdf()
ft_clf = FirstColFtClassifier()
@@ -49,7 +101,6 @@ def test_predict():
assert ft_clf.predict([['woof lol']])[0] == 0
assert ft_clf.predict([['meow lolz']])[0] == 1


def test_predict_proba():
ftdf = _ftdf()
ft_clf = FirstColFtClassifier()