/
setup.py
62 lines (53 loc) · 1.83 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
from setuptools import setup
from setuptools import find_packages
from setuptools import Extension
import numpy
# Cython is an optional build-time dependency: when it is importable the
# extension can be regenerated from its .pyx source, otherwise the build
# falls back to a pre-generated .c file.
try:
    from Cython.Build import cythonize
except ImportError:
    CAN_CYTHONIZE = False
else:
    CAN_CYTHONIZE = True
def get_svmrank_parser_ext():
    """
    Gets the svmrank parser extension.

    Cython is used to compile the ``.pyx`` source when Cython is importable
    and the ``.pyx`` file is present; otherwise the packaged ``.c`` file is
    compiled directly.

    Returns:
        A list of extension modules suitable for ``setup(ext_modules=...)``.
    """
    path = "pytorchltr/datasets/svmrank/parser"
    pyx_path = os.path.join(path, "svmrank_parser.pyx")
    c_path = os.path.join(path, "svmrank_parser.c")
    module_name = "pytorchltr.datasets.svmrank.parser.svmrank_parser"
    # Both build paths end up compiling C code against the numpy headers, so
    # the deprecated-API guard must be defined for both; previously only the
    # Cython path set it, and the .c fallback compiled with deprecation
    # warnings enabled.
    define_macros = [("NPY_NO_DEPRECATED_API", "NPY_1_7_API_VERSION")]

    if CAN_CYTHONIZE and os.path.exists(pyx_path):
        return cythonize([Extension(module_name, [pyx_path],
                                    define_macros=define_macros)])
    return [Extension(module_name, [c_path], define_macros=define_macros)]
# The README is used as the package's long description (markdown). Resolve it
# relative to this file so setup.py works from any working directory, and
# decode it explicitly as UTF-8 so non-ASCII characters don't break the build
# on platforms whose default locale encoding is not UTF-8.
_readme_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                            "README.md")
with open(_readme_path, "rt", encoding="utf-8") as f:
    long_description = f.read()
setup(
    # Package metadata.
    name="pytorchltr",
    version="0.2.1",
    description="Learning to Rank with PyTorch",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/rjagerman/pytorchltr",
    author="Rolf Jagerman",
    author_email="rjagerman@gmail.com",
    license="MIT",
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    # Packaging: ship every package except the test suite.
    packages=find_packages(exclude=("tests", "tests.*",)),
    python_requires=">=3.5",
    # Native extension(s); numpy headers are needed to compile them.
    ext_modules=get_svmrank_parser_ext(),
    include_dirs=[numpy.get_include()],
    # Runtime dependencies.
    install_requires=["numpy",
                      "scikit-learn",
                      "scipy",
                      "torch"],
    # tests_require is deprecated and ignored by pip; it is kept for older
    # tooling, and the same dependencies are also exposed as an installable
    # extra: ``pip install pytorchltr[test]``.
    tests_require=["pytest"],
    extras_require={"test": ["pytest"]},
)