Skip to content

Commit

Permalink
Merge pull request #159 from rth/more-fixes-setup
Browse files Browse the repository at this point in the history
MAINT More fixes to include test/ folder in sdist
  • Loading branch information
rtavenar committed Oct 11, 2019
2 parents 39ad33b + b7499aa commit 1721482
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 9 deletions.
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from setuptools import setup
from setuptools import setup, find_packages
from codecs import open
import numpy
import os
Expand All @@ -22,7 +22,7 @@
long_description=README,
long_description_content_type='text/markdown',
include_dirs=[numpy.get_include()],
packages=['tslearn'],
packages=find_packages(),
package_data={"tslearn": [".cached_datasets/Trace.npz"]},
data_files=[("", ["LICENSE"])],
install_requires=['numpy', 'scipy', 'scikit-learn', 'Cython', 'numba',
Expand Down
2 changes: 1 addition & 1 deletion tslearn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os

__author__ = 'Romain Tavenard romain.tavenard[at]univ-rennes2.fr'
__version__ = "0.2.4"
__version__ = "0.2.5"
__bibtex__ = r"""@misc{tslearn,
title={tslearn: A machine learning toolkit dedicated to time-series data},
author={Tavenard, Romain and Faouzi, Johann and Vandewiele, Gilles},
Expand Down
24 changes: 18 additions & 6 deletions tslearn/tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,19 @@ def _get_all_classes():
base_path = tslearn.__path__
for _, name, _ in pkgutil.walk_packages(path=base_path,
prefix='tslearn.'):
module = __import__(name, fromlist="dummy")
try:
module = __import__(name, fromlist="dummy")
except ImportError:
if name.endswith('shapelets'):
# keras is likely not installed
warnings.warn('Skipped common tests for shapelets '
'as it could not be imported. keras '
'(and tensorflow) are probably not '
'installed!')
continue
else:
raise

all_classes.extend(inspect.getmembers(module, inspect.isclass))
return all_classes

Expand Down Expand Up @@ -169,11 +181,11 @@ def check_estimator(Estimator):
warnings.warn(str(exception), SkipTestWarning)


@pytest.mark.parametrize('estimator', get_estimators('all'))
def test_all_estimators(estimator):
@pytest.mark.parametrize('name, Estimator', get_estimators('all'))
def test_all_estimators(name, Estimator):
"""Test all the estimators in tslearn."""
allow_nan = (hasattr(checks, 'ALLOW_NAN') and
_safe_tags(estimator[1](), "allow_nan"))
_safe_tags(Estimator(), "allow_nan"))
if allow_nan:
checks.ALLOW_NAN.append(estimator[0])
check_estimator(estimator[1])
checks.ALLOW_NAN.append(name)
check_estimator(Estimator)

0 comments on commit 1721482

Please sign in to comment.