diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000..70939d4
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,22 @@
+[paths]
+source =
+    omnipath
+    */site-packages/omnipath
+
+[run]
+branch = true
+parallel = true
+source = omnipath
+omit = */__init__.py
+
+[report]
+exclude_lines =
+    \#.*pragma:\s*no.?cover
+
+    if __name__ == .__main__.
+
+    ^\s*raise AssertionError\b
+    ^\s*raise NotImplementedError\b
+    ^\s*return NotImplemented\b
+show_missing = true
+precision = 2
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..a50e670
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,65 @@
+name: CI
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    branches:
+      - master
+
+jobs:
+  build:
+    runs-on: ${{ matrix.os }}
+    timeout-minutes: 10
+    strategy:
+      max-parallel: 4
+      matrix:
+        python: [3.7, 3.8]  # , 3.9]
+        os: [ubuntu-latest, macos-latest]
+    env:
+      OS: ${{ matrix.os }}
+      PYTHON: ${{ matrix.python }}
+
+    steps:
+      - uses: actions/checkout@v2
+      - name: Set up Python ${{ matrix.python }}
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python }}
+
+      - name: Get pip cache dir
+        id: pip-cache
+        run: |
+          echo "::set-output name=dir::$(pip cache dir)"
+
+      - name: Cache pip
+        uses: actions/cache@v2
+        with:
+          path: ${{ steps.pip-cache.outputs.dir }}
+          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
+          restore-keys: |
+            ${{ runner.os }}-pip-
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install tox tox-gh-actions codecov
+
+      - name: Linting
+        run: |
+          tox -e lint
+
+      - name: Testing
+        run: |
+          tox
+        env:
+          PLATFORM: ${{ matrix.os }}
+
+      - name: Upload coverage to Codecov
+        if: success()
+        env:
+          CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
+          CODECOV_NAME: ${{ matrix.python }}-${{ matrix.os }}
+        run: |
+          codecov --no-color --required --flags unittests
diff --git a/README.rst b/README.rst
index 8d6eb4a..18f4c84 100644
--- a/README.rst
+++ b/README.rst
@@ -1,13 +1,8 @@
-|PyPI| |Travis| |Docs|
+|PyPI| |CI| |Docs| |Coverage|
 
 Python client for the OmniPath web service
 ==========================================
 
-**This package is in planning stage, without any useful functionality yet.**
-Contributions are welcome, please contact us at omnipathdb@gmail.com, open
-issues or send pull requests. Otherwise please check out the resources below
-and return to us later.
-
 The OmniPath database
 ---------------------
 
@@ -65,10 +60,14 @@ certain (not all) annotations of the proteins.
     :target: https://pypi.org/project/omnipath
     :alt: PyPI
 
-.. |Travis| image:: https://travis-ci.org/theislab/omnipath.svg?branch=master
-    :target: https://travis-ci.com/github/saezlab/omnipath
+.. |CI| image:: https://img.shields.io/github/workflow/status/michalk8/omnipath/CI/master
+    :target: https://github.com/michalk8/omnipath/actions?query=workflow:CI
     :alt: CI
 
+.. |Coverage| image:: https://codecov.io/gh/michalk8/omnipath/branch/master/graph/badge.svg?token=5A086KQA51
+    :target: https://codecov.io/gh/michalk8/omnipath
+    :alt: Coverage
+
 .. |Docs| image:: https://img.shields.io/readthedocs/omnipath
     :target: https://omnipath.readthedocs.io/en/latest
     :alt: Documentation
diff --git a/docs/source/index.rst b/docs/source/index.rst
index aaf1d0c..e8cfd8f 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -1,4 +1,4 @@
-|PyPI| |Travis| |Docs|
+|PyPI| |CI| |Docs| |Coverage|
 
 OmniPath
 ========
@@ -20,13 +20,17 @@ This package is a Python equivalent of an R package `OmnipathR`_ for accessing w
     :target: https://pypi.org/project/omnipath
     :alt: PyPI
 
-.. |Travis| image:: https://travis-ci.org/theislab/omnipath.svg?branch=master
-    :target: https://travis-ci.com/github/saezlab/omnipath
+.. |CI| image:: https://img.shields.io/github/workflow/status/michalk8/omnipath/CI/master
+    :target: https://github.com/michalk8/omnipath/actions?query=workflow:CI
     :alt: CI
 
 .. |Docs| image:: https://img.shields.io/readthedocs/omnipath
     :target: https://omnipath.readthedocs.io/en/latest
     :alt: Documentation
 
+.. |Coverage| image:: https://codecov.io/gh/michalk8/omnipath/branch/master/graph/badge.svg?token=5A086KQA51
+    :target: https://codecov.io/gh/michalk8/omnipath
+    :alt: Coverage
+
 .. _Saezlab : https://saezlab.org/
 .. _OmniPathR : https://github.com/saezlab/omnipathR
diff --git a/omnipath/__init__.py b/omnipath/__init__.py
index 29a7396..cab3c70 100644
--- a/omnipath/__init__.py
+++ b/omnipath/__init__.py
@@ -21,6 +21,6 @@ __full_version__ = (
     f"{__version__}+{__full_version__.local}" if __full_version__.local else __version__
 )
 
-__server_version__ = _get_server_version()
+__server_version__ = _get_server_version(options)
 
 del parse, version, _get_server_version
diff --git a/omnipath/_core/cache/_cache.py b/omnipath/_core/cache/_cache.py
index 6bac12f..24efa04 100644
--- a/omnipath/_core/cache/_cache.py
+++ b/omnipath/_core/cache/_cache.py
@@ -125,6 +125,10 @@ def __repr__(self) -> str:
         return str(self)
 
     def __copy__(self) -> "MemoryCache":
        return self
 
+    def copy(self) -> "MemoryCache":
+        """Return self."""
+        return self
+
 
 def clear_cache() -> None:
     """Remove all cached data from :attr:`omnipath.options.cache`."""
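The `MemoryCache.copy` override above mirrors `__copy__`: both return `self`, so shallow-copying an `Options` instance (which `Downloader.__init__` does in the next file) keeps sharing one in-memory store. A minimal sketch of that behavior, using only the classes from this diff::

    from copy import copy

    from omnipath._core.cache._cache import MemoryCache

    mc = MemoryCache()
    mc["key"] = 42

    # Shallow copies alias the same dict-like store, so responses cached
    # through a copied `Options` object stay visible through the original.
    assert copy(mc) is mc and mc.copy() is mc
    assert mc["key"] == 42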
diff --git a/omnipath/_core/downloader/_downloader.py b/omnipath/_core/downloader/_downloader.py
index 0c112aa..3ef7543 100644
--- a/omnipath/_core/downloader/_downloader.py
+++ b/omnipath/_core/downloader/_downloader.py
@@ -2,7 +2,6 @@
 from copy import copy
 from typing import Any, Mapping, Callable, Optional
 from hashlib import md5
-from functools import lru_cache
 from urllib.parse import urljoin
 import json
 import logging
@@ -37,15 +36,20 @@ def __init__(self, opts: Optional[Options] = None):
         if opts is None:
             from omnipath import options as opts
 
+        if not isinstance(opts, Options):
+            raise TypeError(
+                f"Expected `opts` to be of type `Options`, found {type(opts).__name__}."
+            )
+
         self._session = Session()
-        self._options = copy(opts)
+        self._options = copy(opts)  # this does not copy MemoryCache
 
         if self._options.num_retries > 0:
             adapter = HTTPAdapter(
                 max_retries=Retry(
                     total=self._options.num_retries,
                     redirect=5,
-                    method_whitelist=["HEAD", "GET", "OPTIONS"],
+                    allowed_methods=["HEAD", "GET", "OPTIONS"],
                     status_forcelist=[413, 429, 500, 502, 503, 504],
                     backoff_factor=1,
                 )
@@ -176,13 +180,10 @@ def __repr__(self) -> str:
         return str(self)
 
 
-@lru_cache()
-def _get_server_version() -> str:
+def _get_server_version(options: Options) -> str:
     """Try and get the server version."""
     import re
 
-    from omnipath import options
-
     def callback(fp: BytesIO) -> str:
         """Parse the version."""
         return re.findall(
@@ -191,7 +192,7 @@ def callback(fp: BytesIO) -> str:
     try:
         if not options.autoload:
-            raise ValueError("Autoload is disallowed.")
+            raise ValueError("Autoload is disabled.")
 
         with Options.from_options(
             options,
@@ -209,6 +210,6 @@ def callback(fp: BytesIO) -> str:
             is_final=False,
         )
     except Exception as e:
-        logging.debug(f"Unable to get server version. Reason `{e}`")
+        logging.debug(f"Unable to get server version. Reason: `{e}`")
 
     return UNKNOWN_SERVER_VERSION
diff --git a/omnipath/_core/query/_query.py b/omnipath/_core/query/_query.py
index c028e9e..fcd1597 100644
--- a/omnipath/_core/query/_query.py
+++ b/omnipath/_core/query/_query.py
@@ -6,7 +6,6 @@
 from omnipath.constants._constants import FormatterMeta, ErrorFormatter
 from omnipath._core.query._query_validator import (
-    DummyValidator,
     EnzsubValidator,
     ComplexesValidator,
     IntercellValidator,
@@ -58,8 +57,7 @@ def __new__(cls, clsname, superclasses, attributedict):  # noqa: D102
         for i, synonym in enumerate(_get_synonyms(k.lower())):
             attributedict[f"{k}_{i}"] = synonym
 
-        res = super().__new__(cls, clsname, superclasses, attributedict)
-        return res
+        return super().__new__(cls, clsname, superclasses, attributedict)
 
 
 class QueryMeta(SynonymizerMeta, FormatterMeta):  # noqa: D101
@@ -69,13 +67,13 @@ class Query(ErrorFormatter, Enum, metaclass=QueryMeta):  # noqa: D101
     @property
     def _query_name(self) -> str:
-        """Convert synonym to actual query parameter name."""
+        """Convert the synonym to an actual query parameter name."""
         return "_".join(self.name.split("_")[:-1])
 
     @property
-    def _delagate(self):
+    def _delegate(self):
         """Delegate the validation."""
-        return getattr(self.__validator__, self._query_name, DummyValidator)
+        return getattr(self.__validator__, self._query_name)
 
     @property
     def param(self) -> str:
@@ -85,23 +83,23 @@ def param(self) -> str:
     @property
     def valid(self) -> Optional[FrozenSet[str]]:
         """Return the set of valid values for :paramref:`param`."""
-        return self._delagate.valid
+        return self._delegate.valid
 
     @property
     def annotation(self) -> type:
         """Return type annotations for :paramref:`param`."""
-        return self._delagate.annotation
+        return self._delegate.annotation
 
     @property
     def doc(self) -> Optional[str]:
         """Return the docstring for :paramref:`param`."""
-        return self._delagate.doc
+        return self._delegate.doc
 
     def __call__(
         self, value: Optional[Union[str, Sequence[str]]]
     ) -> Optional[Set[str]]:
         """%(validate)s"""  # noqa: D401
-        return self._delagate(value)
+        return self._delegate(value)
 
 
 class EnzsubQuery(Query):  # noqa: D101
@@ -139,5 +137,14 @@ def __call__(
     @property
     def endpoint(self) -> str:
-        """Get the API endpoint for this type of query.."""
+        """Get the API endpoint for this type of query."""
         return self.name.lower()
+
+
+__all__ = [
+    EnzsubQuery,
+    InteractionsQuery,
+    ComplexesQuery,
+    AnnotationsQuery,
+    IntercellQuery,
+]
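With `DummyValidator` gone, `_delegate` now raises for an unknown parameter instead of silently accepting it, and the new `__all__` exports the per-endpoint `Query` enums the request classes below build on: `params()` walks the matching enum and returns one entry per accepted, validated query parameter. A short usage sketch, mirroring `test_dorothea_params` later in this diff::

    from omnipath._core.requests.interactions._interactions import Dorothea

    params = Dorothea.params()
    # one key per accepted query parameter; `datasets` is managed
    # internally and therefore filtered out of the public signature
    assert "dorothea_levels" in params
    assert "datasets" not in params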
diff --git a/omnipath/_core/query/_query_validator.py b/omnipath/_core/query/_query_validator.py
index 2f13749..05eb00c 100644
--- a/omnipath/_core/query/_query_validator.py
+++ b/omnipath/_core/query/_query_validator.py
@@ -22,14 +22,16 @@
     Bool_t,
     None_t,
     Strseq_t,
+    License_t,
     Organism_t,
 )
 from omnipath._core.utils._options import Options
 from omnipath.constants._constants import NoValue
+from omnipath.constants._pkg_constants import Key, Format
 from omnipath._core.downloader._downloader import Downloader
 
 
-def _to_string_set(item: Union[str, Sequence[str]]) -> Set[str]:
+def _to_string_set(item: Union[Any, Sequence[Any]]) -> Set[str]:
     """
     Convert ``item`` to a `str` set.
 
@@ -107,6 +109,8 @@ def __call__(self, needle: Optional[Set[str]]) -> Optional[Set[str]]:
         """
         if needle is None:
             return None
+        elif isinstance(needle, bool):
+            needle = int(needle)
 
         needle = _to_string_set(needle)
 
         if self.haystack is None:
@@ -137,7 +141,7 @@ def __new__(cls, clsname, superclasses, attributedict):  # noqa: D102
             )
             use_default = True
             old_members = list(attributedict._member_names)
-            old_values = []
+            old_values = cls._remove_old_members(attributedict)
 
             if endpoint is None:
                 if len(old_members):
@@ -156,15 +160,12 @@
             ) as opt:
                 try:
                     logging.debug("Attempting to construct classes from the server")
-
-                    url = urljoin(urljoin(opt.url, "queries/"), endpoint)
                     res = Downloader(opt).maybe_download(
-                        url, callback=json.load, params={"format": "json"}
+                        urljoin(urljoin(opt.url, f"{Key.QUERIES.s}/"), endpoint),
+                        callback=json.load,
+                        params={Key.FORMAT.s: Format.JSON.s},
                     )
 
-                    # remove the default values
-                    old_values = cls._remove_old_members(attributedict)
-
                     if len({str(k).upper() for k in res.keys()}) != len(res):
                         raise RuntimeError(
                             f"After upper casing, key will not be unique: `{list(res.keys())}`."
                         )
@@ -188,16 +189,16 @@
                         attributedict[key] = cls.Validator(param=k)
                 except Exception as e:
                     logging.debug(
-                        f"Unable to construct classes from the server. Reason `{e}`"
+                        f"Unable to construct classes from the server. Reason: `{e}`"
                     )
                     use_default = True
 
         if use_default:
-            if endpoint not in (None, "dummy"):
+            if endpoint is not None:
                 logging.debug(
-                    f"Using predifined class: `{clsname}`." + ""
+                    f"Using predefined class: `{clsname}`."
+ "" if options.autoload - else " Consider specifying `omnipath.options.autoload=True`" + else " Consider specifying `omnipath.options.autoload = True`" ) _ = cls._remove_old_members(attributedict) @@ -250,7 +251,7 @@ class EnzsubValidator(QueryValidatorMixin): # noqa: D101 FORMAT: Str_t = () GENESYMBOLS: Bool_t = () HEADER: Str_t = () - LICENSE: Str_t = () + LICENSE: License_t = () LIMIT: Int_t = () MODIFICATION: Str_t = () ORGANISMS: Organism_t = () @@ -273,7 +274,7 @@ class InteractionsValidator(QueryValidatorMixin): # noqa: D101 FORMAT: Str_t = () GENESYMBOLS: Bool_t = () HEADER: Str_t = () - LICENSE: Str_t = () + LICENSE: License_t = () LIMIT: Int_t = () ORGANISMS: Organism_t = () PARTNERS: Strseq_t = () @@ -293,7 +294,7 @@ class ComplexesValidator(QueryValidatorMixin): # noqa: D101 FIELDS: Strseq_t = () FORMAT: Str_t = () HEADER: Str_t = () - LICENSE: Str_t = () + LICENSE: License_t = () LIMIT: Int_t = () PASSWORD: Str_t = () PROTEINS: Strseq_t = () @@ -307,7 +308,7 @@ class AnnotationsValidator(QueryValidatorMixin): # noqa: D101 FORMAT: Str_t = () GENESYMBOLS: Bool_t = () HEADER: Str_t = () - LICENSE: Str_t = () + LICENSE: License_t = () LIMIT: Int_t = () PASSWORD: Str_t = () PROTEINS: Strseq_t = () @@ -323,7 +324,7 @@ class IntercellValidator(QueryValidatorMixin): # noqa: D101 FIELDS: Strseq_t = () FORMAT: Str_t = () HEADER: None_t = () - LICENSE: Str_t = () + LICENSE: License_t = () LIMIT: Int_t = () PARENT: Str_t = () PASSWORD: Str_t = () @@ -344,10 +345,6 @@ class IntercellValidator(QueryValidatorMixin): # noqa: D101 TRANSMITTER: Bool_t = () -class DummyValidator(QueryValidatorMixin): # noqa: D101 - pass - - __all__ = [ EnzsubValidator, InteractionsValidator, diff --git a/omnipath/_core/query/_types.py b/omnipath/_core/query/_types.py index 771698a..eb3fcea 100644 --- a/omnipath/_core/query/_types.py +++ b/omnipath/_core/query/_types.py @@ -1,9 +1,14 @@ -from typing import Union, Literal, Optional, Sequence +from typing import Union, Optional, Sequence + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal -# TODO: I don't think 3.6 has Literal Strseq_t = Optional[Union[str, Sequence[str]]] Organism_t = Literal["human", "mouse", "rat"] +License_t = Literal["academic", "commercial"] Bool_t = Optional[bool] Str_t = Optional[str] Int_t = Optional[int] diff --git a/omnipath/_core/requests/_annotations.py b/omnipath/_core/requests/_annotations.py index d2b4097..8271493 100644 --- a/omnipath/_core/requests/_annotations.py +++ b/omnipath/_core/requests/_annotations.py @@ -1,11 +1,11 @@ -from typing import Any, Dict, Union, Mapping, Iterable, final +from typing import Any, Dict, Union, Mapping, Iterable import pandas as pd from omnipath._core.query import QueryType from omnipath._core.utils._docs import d from omnipath._core.requests._request import OmnipathRequestABC -from omnipath.constants._pkg_constants import Key +from omnipath.constants._pkg_constants import Key, final @final @@ -22,6 +22,7 @@ def _remove_params(self, params: Dict[str, Any]) -> Dict[str, Any]: return params + @classmethod @d.dedent def params(cls) -> Dict[str, Any]: """%(query_params)s""" @@ -68,10 +69,10 @@ def get( if len(proteins) > 600: raise ValueError( - "Cannot download information for more than `600` proteins yet." + "Cannot download annotations for more than `600` proteins yet." 
diff --git a/omnipath/_core/requests/_complexes.py b/omnipath/_core/requests/_complexes.py
index ebf0eb4..488c2fe 100644
--- a/omnipath/_core/requests/_complexes.py
+++ b/omnipath/_core/requests/_complexes.py
@@ -1,14 +1,15 @@
-from typing import Any, Union, Mapping, Iterable, Optional, final
+from typing import Any, Union, Mapping, Iterable, Optional
 import logging
 
 import pandas as pd
 
 from omnipath._core.query import QueryType
-from omnipath._core.requests._request import CommonPostProcessor
+from omnipath._core.requests._request import OrganismGenesymbolsRemover
+from omnipath.constants._pkg_constants import final
 
 
 @final
-class Complexes(CommonPostProcessor):
+class Complexes(OrganismGenesymbolsRemover):
     """Request information about protein complexes from [OmniPath]_."""
 
     __string__ = frozenset(
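`Complexes` (and `Intercell` just below) now derive from `OrganismGenesymbolsRemover`, introduced later in this diff, so the `organism` and `genesymbols` parameters are dropped from their signatures instead of being silently ignored. The observable effect, assuming `omnipath.requests.Complexes` is the public alias and that `Key.ORGANISM`/`Key.GENESYMBOLS` hold the literal strings ``organism`` and ``genesymbols``::

    from omnipath.requests import Complexes

    params = Complexes.params()
    # both keys are popped by OrganismGenesymbolsRemover
    assert "organism" not in params
    assert "genesymbols" not in params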
diff --git a/omnipath/_core/requests/_intercell.py b/omnipath/_core/requests/_intercell.py
index e8e33ee..9ff9988 100644
--- a/omnipath/_core/requests/_intercell.py
+++ b/omnipath/_core/requests/_intercell.py
@@ -1,15 +1,16 @@
-from typing import Any, Tuple, Mapping, Iterable, Optional, Sequence, final
+from typing import Any, Tuple, Mapping, Iterable, Optional, Sequence
 
 import pandas as pd
 
 from omnipath._core.query import QueryType
 from omnipath._core.query._types import Strseq_t
-from omnipath._core.requests._request import CommonPostProcessor
-from omnipath.constants._pkg_constants import Key, Format
+from omnipath._core.requests._request import OrganismGenesymbolsRemover
+from omnipath.constants._pkg_constants import Key, Format, final
+from omnipath._core.query._query_validator import _to_string_set
 
 
 @final
-class Intercell(CommonPostProcessor):
+class Intercell(OrganismGenesymbolsRemover):
     """
     Request `intercell` annotations from [OmniPath]_.
 
@@ -32,9 +33,9 @@ def _resource_filter(
         generic_categories: Optional[Sequence[str]] = None,
         **kwargs,
     ) -> bool:
-        return generic_categories is None or set(
+        return generic_categories is None or _to_string_set(
             data.get(Key.GENERIC_CATEGORIES.s, set())
-        ) & set(generic_categories)
+        ) & _to_string_set(generic_categories)
 
     @classmethod
     def resources(cls, generic_categories: Strseq_t = None) -> Tuple[str]:
diff --git a/omnipath/_core/requests/_request.py b/omnipath/_core/requests/_request.py
index 08e922a..2d367f1 100644
--- a/omnipath/_core/requests/_request.py
+++ b/omnipath/_core/requests/_request.py
@@ -10,15 +10,16 @@
     Iterable,
     Optional,
     Sequence,
-    final,
 )
+from operator import itemgetter
 from functools import partial
 import logging
 
 from pandas.api.types import is_float_dtype, is_numeric_dtype
 import pandas as pd
 
-from omnipath.constants import Organism
+from omnipath import options
+from omnipath.constants import License, Organism
 from omnipath._core.query import QueryType
 from omnipath._core.utils._docs import d
 from omnipath._core.requests._utils import (
@@ -26,7 +27,7 @@
     _inject_api_method,
     _strip_resource_label,
 )
-from omnipath.constants._pkg_constants import DEFAULT_FIELD, Key, Format
+from omnipath.constants._pkg_constants import DEFAULT_FIELD, Key, Format, final
 from omnipath._core.downloader._downloader import Downloader
 
 
@@ -56,8 +57,6 @@ class OmnipathRequestABC(ABC, metaclass=OmnipathRequestMeta):
     _query_type: Optional[QueryType] = None
 
     def __init__(self):
-        from omnipath import options
-
         self._downloader = Downloader(options)
 
     @classmethod
@@ -83,16 +82,20 @@ def _docs(cls) -> Dict[str, Optional[str]]:
         return {q.param: q.doc for q in cls._query_type.value}
 
     def _get(self, **kwargs) -> pd.DataFrame:
-        kwargs, callback = self._convert_params(kwargs)
-        kwargs = self._inject_fields(kwargs)
         kwargs = self._remove_params(kwargs)
+        kwargs = self._inject_fields(kwargs)
+        kwargs, callback = self._convert_params(kwargs)
         kwargs = self._validate_params(kwargs)
         kwargs = self._finalize_params(kwargs)
 
         res = self._downloader.maybe_download(
             self._query_type.endpoint, params=kwargs, callback=callback, is_final=False
         )
-        return self._post_process(self._convert_dtypes(res))
+
+        if self._downloader._options.convert_dtypes:
+            res = self._convert_dtypes(res)
+
+        return self._post_process(res)
 
     def _convert_params(
         self, params: Dict[str, Any]
@@ -105,12 +108,16 @@
         fmt = Format(params.get(Key.FORMAT.s, Format.TSV.s))
         if fmt not in (Format.TSV, Format.JSON):
             logging.warning(
-                f"Invalid `{Key.FORMAT.s}={fmt.s!r}`. Switching to `{Key.FORMAT.s}={Format.TSV.s!r}`"
+                f"Invalid `{Key.FORMAT.s}={fmt.s!r}`. Using `{Key.FORMAT.s}={Format.TSV.s!r}`"
             )
             fmt = Format.TSV
         callback = self._tsv_reader if fmt == Format.TSV else self._json_reader
 
+        params[Key.FORMAT.s] = fmt.s
+        params[Key.LICENSE.s] = License(params.get(Key.LICENSE.s, License.ACADEMIC)).s
+        if self._downloader._options.password is not None:
+            params.setdefault(Key.PASSWORD.s, self._downloader._options.password)
 
         return params, callback
 
@@ -118,14 +125,16 @@
     def _inject_fields(self, params: Dict[str, Any]) -> Dict[str, Any]:
         try:
             _inject_params(
                 params,
-                key=self._query_type("fields").param,
+                key=self._query_type(Key.FIELDS.value).param,
                 value=getattr(DEFAULT_FIELD, self._query_type.name).value,
             )
         except AttributeError:
             # no default field for this query
             pass
         except Exception as e:
-            logging.warning(f"Unable to inject `fields` for `{self}`. Reason: `{e}`")
+            logging.warning(
+                f"Unable to inject `{Key.FIELDS.value}` for `{self}`. Reason: `{e}`"
+            )
 
         return params
 
@@ -134,7 +143,7 @@ def _validate_params(
     ) -> Dict[str, Optional[Union[str, Sequence[str]]]]:
         """For each passed parameter, validate if it has the correct value."""
         for k, v in params.items():
-            # first get the validator, then validate
+            # first get the validator for the parameter, then validate
             params[k] = self._query_type(k)(v)
 
         return params
@@ -149,17 +158,16 @@ def _finalize_params(self, params: Dict[str, Any]) -> Dict[str, str]:
             elif isinstance(v, (int, float)):
                 res[k] = str(v)
             elif isinstance(v, Iterable):
-                res[k] = ",".join(v)
+                res[k] = ",".join(sorted(v))
             elif isinstance(v, Enum):
                 res[k] = str(v.value)
             elif v is not None:
                 logging.warning(f"Unable to process parameter `{k}={v}`. Ignoring")
 
-        return res
+        return dict(sorted(res.items(), key=itemgetter(0)))
 
     def _convert_dtypes(self, res: pd.DataFrame, **_) -> pd.DataFrame:
         """Automatically convert dtypes for this type of query."""
-        from omnipath import options
 
         def to_logical(col: pd.Series) -> pd.Series:
             if is_numeric_dtype(col):
@@ -190,10 +198,9 @@ def handle_string(df: pd.DataFrame, columns: frozenset) -> None:
                 f"Expected the result to be of type `pandas.DataFrame`, found `{type(res).__name__}`."
             )
 
-        if options.convert_dtypes:
-            handle_logical(res, self.__logical__)
-            handle_categorical(res, self.__categorical__)
-            handle_string(res, self.__string__)
+        handle_logical(res, self.__logical__)
+        handle_categorical(res, self.__categorical__)
+        handle_string(res, self.__string__)
 
         return res
 
@@ -215,17 +222,28 @@ def _resources(self, **kwargs) -> Tuple[str]:
             sorted(
                 res
                 for res, params in self._downloader.resources.items()
-                if self._query_type.name.lower() in params[Key.QUERIES.s]
+                if self._query_type.endpoint in params.get(Key.QUERIES.s, {})
                 and self._resource_filter(
-                    params[Key.QUERIES.s][self._query_type.name.lower()], **kwargs
+                    params[Key.QUERIES.s][self._query_type.endpoint], **kwargs
                 )
             )
         )
 
-    # TODO: doc
-    @abstractmethod
     def _remove_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
-        pass
+        """
+        Remove parameters from this query.
+
+        Parameters
+        ----------
+        params
+            The parameters to filter.
+
+        Returns
+        -------
+        :class:`dict`
+            The filtered parameters.
+        """
+        return params
 
     @abstractmethod
     def _post_process(self, df: pd.DataFrame) -> pd.DataFrame:
@@ -279,22 +297,6 @@ class CommonPostProcessor(OmnipathRequestABC, ABC):
     :class:`omnipath.interactions.InteractionRequest` and :class:`omnipath.requests.Enzsub`.
     """
 
-    def _remove_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
-        params.pop(Key.ORGANISM.s, None)
-        params.pop(Key.GENESYMBOLS.s, None)
-
-        return params
-
-    @classmethod
-    @d.dedent
-    def params(cls) -> Dict[str, Any]:
-        """%(query_params)s"""
-        params = super().params()
-        params.pop(Key.ORGANISM.s, None)
-        params.pop(Key.GENESYMBOLS.s, None)
-
-        return params
-
     def _post_process(self, df: pd.DataFrame) -> pd.DataFrame:
         """
         Add number of resources and references for each row in the resulting ``df``.
@@ -332,12 +334,32 @@ def _post_process(self, df: pd.DataFrame) -> pd.DataFrame:
 
         return df
 
 
+class OrganismGenesymbolsRemover(CommonPostProcessor, ABC):
+    """Class that removes organism and genesymbols keys from the query."""
+
+    def _remove_params(self, params: Dict[str, Any]) -> Dict[str, Any]:
+        params.pop(Key.ORGANISM.s, None)
+        params.pop(Key.GENESYMBOLS.s, None)
+
+        return params
+
+    @classmethod
+    @d.dedent
+    def params(cls) -> Dict[str, Any]:
+        """%(query_params)s"""
+        params = super().params()
+        params.pop(Key.ORGANISM.s, None)
+        params.pop(Key.GENESYMBOLS.s, None)
+
+        return params
+
+
 @final
-class Enzsub(CommonPostProcessor):
+class Enzsub(OrganismGenesymbolsRemover):
     """
     Request enzyme-substrate relationships from [OmniPath]_.
 
-    Imports the enzyme-substrate (more exactly, enzyme-PTM) relationship `database `__.
+    Imports the enzyme-substrate (more exactly, enzyme-PTM) relationships `database `__.
     """
 
     __string__ = frozenset({"enzyme", "substrate"})
diff --git a/omnipath/_core/requests/_utils.py b/omnipath/_core/requests/_utils.py
index d208a22..fbb9985 100644
--- a/omnipath/_core/requests/_utils.py
+++ b/omnipath/_core/requests/_utils.py
@@ -5,14 +5,15 @@
 import inspect
 
 import wrapt
+import typing_extensions  # noqa: F401
 
 import pandas as pd
 
 from omnipath._core.utils._docs import d
 
 
-@d.get_full_descriptionf("get")
-@d.get_sectionsf("get", sections=["Parameters", "Returns"])
+@d.get_full_description(base="get")
+@d.get_sections(base="get", sections=["Parameters", "Returns"])
 def _get_helper(cls: type, **kwargs) -> pd.DataFrame:
     """
     Perform a request to the [OmniPath]_ web service.
@@ -79,10 +80,12 @@ def argspec_factory(orig_fn: Callable) -> Callable:
         + [Parameter("kwargs", kind=Parameter.VAR_KEYWORD)]
     )
 
     # modify locals() for argspec factory
+    import omnipath  # noqa: F401
+
     NoneType, pandas = type(None), pd
 
     exec(
-        f"def adapter{sig}: pass",
+        f"def adapter{sig}: pass".replace(" /,", ""),
         globals(),
         locals(),
     )
@@ -126,16 +129,23 @@ def _inject_params(
 
 
 def _split_unique_join(data: pd.Series, func: Optional[Callable] = None) -> pd.Series:
-    data = data.astype(str).str.split(";")
+    mask = ~pd.isnull(data.astype("string"))
+    data = data[mask]
+    data = data.str.split(";")
 
     if func is None:
-        return data.apply(
+        data = data.apply(
             lambda row: ";".join(sorted(set(map(str, row))))
             if isinstance(row, Iterable)
-            else ""
+            else row
         )
+    else:
+        data = data.apply(func)
+
+    res = pd.Series([None] * len(mask))
+    res.loc[mask] = data
 
-    return data.apply(func)
+    return res
 
 
 def _strip_resource_label(
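The rewritten `_split_unique_join` first masks missing values, so `None` entries survive the round-trip instead of being stringified to ``"nan"``, while duplicates within a single `;`-separated cell are collapsed and sorted. A sketch of the intended behavior on data shaped like the `string_series` fixture introduced below::

    import pandas as pd

    from omnipath._core.requests._utils import _split_unique_join

    s = pd.Series(["foo:123", None, "foo;foo;foo"])
    # expected roughly: ["foo:123", None, "foo"] - missing stays missing
    print(_split_unique_join(s).tolist())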
diff --git a/omnipath/_core/requests/interactions/_interactions.py b/omnipath/_core/requests/interactions/_interactions.py
index 1e86135..bedd574 100644
--- a/omnipath/_core/requests/interactions/_interactions.py
+++ b/omnipath/_core/requests/interactions/_interactions.py
@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Any, Dict, Tuple, Union, Mapping, Iterable, Optional, Sequence, final
+from typing import Any, Set, Dict, Tuple, Union, Mapping, Iterable, Optional, Sequence
 
 import pandas as pd
 
@@ -8,11 +8,27 @@
 from omnipath._core.utils._docs import d
 from omnipath._core.requests._utils import _inject_params
 from omnipath._core.requests._request import CommonPostProcessor
-from omnipath.constants._pkg_constants import Key
+from omnipath.constants._pkg_constants import Key, final
 
 Datasets_t = Union[str, InteractionDataset, Sequence[str], Sequence[InteractionDataset]]
 
 
+def _to_dataset_set(
+    datasets, name: str, none_value: Iterable[InteractionDataset]
+) -> Set[InteractionDataset]:
+    if isinstance(datasets, (str, InteractionDataset)):
+        datasets = (datasets,)
+    elif datasets is None:
+        datasets = none_value
+
+    if not isinstance(datasets, Iterable):
+        raise TypeError(
+            f"Expected `{name}` to be an `Iterable`, found `{type(datasets).__name__}`."
+        )
+
+    return {InteractionDataset(d) for d in datasets}
+
+
 @d.dedent
 class InteractionRequest(CommonPostProcessor, ABC):
     """
@@ -51,23 +67,13 @@ def __init__(
     ):
         super().__init__()
 
-        if isinstance(datasets, (str, InteractionDataset)):
-            datasets = (datasets,)
-        elif datasets is None:
-            datasets = set(InteractionDataset)
-
-        if not isinstance(datasets, Iterable):
-            raise TypeError(
-                f"Expected `datasets` to be an `Iterable`, found `{type(datasets).__name__}`."
-            )
-
-        datasets = {InteractionDataset(d) for d in datasets}
-        exclude = set() if exclude is None else {InteractionDataset(e) for e in exclude}
+        datasets = _to_dataset_set(datasets, "datasets", set(InteractionDataset))
+        exclude = _to_dataset_set(exclude, "exclude", set())
         datasets = datasets - exclude
 
         if not len(datasets):
             raise ValueError(
-                f"After excluding `{sorted(exclude)}` datasets, none were left."
+                f"After excluding `{len(exclude)}` datasets, none were left."
             )
 
         self._datasets = datasets
@@ -95,7 +101,7 @@ def _resource_filter(
         datasets: Optional[Sequence[InteractionDataset]] = None,
     ) -> bool:
         res = datasets is None or (
-            {InteractionDataset(d) for d in data[Key.DATASETS.s]}
+            {InteractionDataset(d) for d in data.get(Key.DATASETS.s, {})}
             & {InteractionDataset(d) for d in datasets}
         )
 
@@ -216,11 +222,10 @@ class Transcriptional(InteractionRequest):
     """
     Request all `TF-target` interactions of [OmniPath]_.
 
-    Imports the `dataset `__ which contains
+    Imports the `dataset `__ which contains
     transcription factor-target protein coding gene interactions.
     """
 
-    # TODO: needs to be checked
     def __init__(self):
         super().__init__((InteractionDataset.DOROTHEA, InteractionDataset.TF_TARGET))
 
@@ -269,7 +274,7 @@ def __init__(self):
 
 
 @final
-class OmniPath(CommonParamFilter):
+class OmniPath(InteractionRequest):
     """
     Request interactions from the `omnipath` dataset.
 
@@ -279,6 +284,10 @@ class OmniPath(CommonParamFilter):
     This part of the interaction database was compiled in a similar way as it
     has been presented in [OmniPath16]_.
""" + @classmethod + def _filter_params(cls, params: Dict[str, Any]) -> Dict[str, Any]: + return super()._filter_params(params) + __string__ = frozenset({"source", "target", "dip_url"}) __logical__ = frozenset( { @@ -327,7 +336,9 @@ def _filter_params(cls, params: Dict[str, Any]) -> Dict[str, Any]: @classmethod @d.dedent @final - def get(cls, exclude: Optional[Sequence[str]], **kwargs) -> pd.DataFrame: + def get( + cls, exclude: Optional[Sequence[Datasets_t]] = None, **kwargs + ) -> pd.DataFrame: """ %(get.full_desc)s diff --git a/omnipath/_core/utils/_options.py b/omnipath/_core/utils/_options.py index 27314d6..04deaa1 100644 --- a/omnipath/_core/utils/_options.py +++ b/omnipath/_core/utils/_options.py @@ -31,11 +31,14 @@ def _is_valid_url(_instance, _attribute: attr.Attribute, value: str) -> NoReturn pr = urlparse(value) if not pr.scheme or not pr.netloc: - raise ValueError(f"Invalid URL `{value}`.") + raise ValueError(f"Invalid URL: `{value}`.") -def _cache_converter(value: Optional[Union[str, Path]]) -> Cache: +def _cache_converter(value: Optional[Union[str, Path, Cache]]) -> Cache: """Convert ``value`` to :class:`omnipath._core.cache.Cache`.""" + if isinstance(value, Cache): + return value + if value is None: return MemoryCache() @@ -129,36 +132,52 @@ class Options: on_setattr=attr.setters.validate, ) - def _create_config(self): + def _create_config(self, section: Optional[str] = None): + section = self.url if section is None else section + _is_valid_url(None, None, section) config = configparser.ConfigParser() # do not save the password - config[self.url] = { + config[section] = { "license": self.license.value, "cache_dir": str(self.cache.path), + "autoload": self.autoload, + "convert_dtypes": self.convert_dtypes, "num_retries": self.num_retries, "timeout": self.timeout, "chunk_size": self.chunk_size, "progress_bar": self.progress_bar, - "autoload": self.autoload, - "convert_dtypes": self.convert_dtypes, } return config @classmethod - def from_config(cls) -> "Options": - """Return the options from a configuration file.""" + def from_config(cls, section: Optional[str] = None) -> "Options": + """ + Return the options from a configuration file. + + Parameters + ---------- + section + Section of the `.ini` file from which to create the options. It corresponds to the URL of the server. + If `None`, use default URL. + + Returns + ------- + :class:`omnipath._cores.utils.Options` + The options. 
+ """ if not cls.config_path.is_file(): return cls().write() - section = DEFAULT_OPTIONS.url - - config = configparser.ConfigParser() + config = configparser.ConfigParser(default_section=DEFAULT_OPTIONS.url) config.read(cls.config_path) + section = DEFAULT_OPTIONS.url if section is None else section + _is_valid_url(None, None, section) + _ = config.get(section, "cache_dir") + cache = config.get(section, "cache_dir", fallback=DEFAULT_OPTIONS.cache_dir) - if cache == "None": - cache = None + cache = None if cache == "None" else cache return cls( url=section, @@ -208,20 +227,20 @@ def from_options(cls, options: "Options", **kwargs) -> "Options": ) kwargs = {k: v for k, v in kwargs.items() if hasattr(options, k)} + return cls(**{**options.__dict__, **kwargs}) - def write(self) -> NoReturn: + def write(self, section: Optional[str] = None) -> NoReturn: """Write the current options to a configuration file.""" self.config_path.parent.mkdir(exist_ok=True) - config = self._create_config() with open(self.config_path, "w") as fout: - config.write(fout) + self._create_config(section).write(fout) return self def __enter__(self): - return self + return self.from_options(self) def __exit__(self, exc_type, exc_val, exc_tb): pass diff --git a/omnipath/constants/_constants.py b/omnipath/constants/_constants.py index 451da05..5f1d4e0 100644 --- a/omnipath/constants/_constants.py +++ b/omnipath/constants/_constants.py @@ -71,10 +71,11 @@ def s(self) -> str: @unique +@document_enum class License(PrettyEnumMixin): """License types.""" - ACADEMIC = "academic" # doc: academic license. + ACADEMIC = "academic" # doc: Academic license. COMMERCIAL = "commercial" # doc: Commercial license. diff --git a/omnipath/constants/_pkg_constants.py b/omnipath/constants/_pkg_constants.py index 6916c70..64142fb 100644 --- a/omnipath/constants/_pkg_constants.py +++ b/omnipath/constants/_pkg_constants.py @@ -1,8 +1,14 @@ +from os import environ from pathlib import Path from omnipath.constants import License, Organism from omnipath.constants._constants import PrettyEnumMixin +try: + from typing import final +except ImportError: + from typing_extensions import final # noqa: F401 + class DEFAULT_FIELD(PrettyEnumMixin): """Default values for ``field`` parameter.""" @@ -31,7 +37,9 @@ class DEFAULT_OPTIONS: cache_dir: Path = Path.home() / ".cache" / "omnipathdb" mem_cache = None progress_bar: bool = True - autoload: bool = True + autoload: bool = ( + environ.get("OMNIPATH_AUTOLOAD", "") == "" + ) # this is done because for testing purposes convert_dtypes: bool = True @@ -40,7 +48,7 @@ class Endpoint(PrettyEnumMixin): RESOURCES = "resources" ABOUT = "about" - INFO = "info" + INFO = "info" # not used # TODO: refactor me @@ -51,6 +59,8 @@ class Key(PrettyEnumMixin): # noqa: D101 DATASETS = "datasets" LICENSE = "license" QUERIES = "queries" + FIELDS = "fields" + PASSWORD = "password" INTERCELL_SUMMARY = "intercell_summary" GENERIC_CATEGORIES = "generic_categories" diff --git a/requirements.txt b/requirements.txt index 4c60275..d7355f6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,5 +6,6 @@ inflect>=4.1.0 pandas>=1.1.4 requests>=2.24.0 tqdm>=4.51.0 +typing_extensions>=3.7.4.3 urllib3>=1.25.11 wrapt>=1.12.0 diff --git a/setup.py b/setup.py index 1dced1d..5b0adc3 100644 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ description=Path("README.rst").read_text("utf-8").splitlines()[3], long_description=Path("README.rst").read_text("utf-8"), description_content_type="text/x-rst; charset=UTF-8", + 
diff --git a/setup.py b/setup.py
index 1dced1d..5b0adc3 100644
--- a/setup.py
+++ b/setup.py
@@ -40,6 +40,7 @@
         description=Path("README.rst").read_text("utf-8").splitlines()[3],
         long_description=Path("README.rst").read_text("utf-8"),
         description_content_type="text/x-rst; charset=UTF-8",
+        long_description_content_type="text/x-rst; charset=UTF-8",
         # links
         url="https://omnipathdb.org/",
         download_url="https://github.com/saezlab/omnipath/releases/",
@@ -149,14 +150,14 @@
         }
     ),
     classifiers=[
-        "Development Status :: 1 - Planning",
+        "Development Status :: 5 - Production/Stable",
         "Intended Audience :: Developers",
         "Intended Audience :: Science/Research",
         "License :: OSI Approved :: MIT License",
         "Operating System :: OS Independent",
         "Natural Language :: English",
+        "Typing :: Typed",
         "Programming Language :: Python :: 3",
-        "Programming Language :: Python :: 3.6",
         "Programming Language :: Python :: 3.7",
         "Programming Language :: Python :: 3.8",
         "Programming Language :: Python :: 3.9",
@@ -165,7 +166,7 @@
     # package installation
     packages=find_packages(),
    zip_safe=False,
-    python_required=">=3.6",
+    python_requires=">=3.7",
    include_package_data=False,
    # dependency_links = deplinks
    install_requires=list(
diff --git a/tests/_data/import_intercell_result.pickle b/tests/_data/import_intercell_result.pickle
new file mode 100644
index 0000000..35f4238
Binary files /dev/null and b/tests/_data/import_intercell_result.pickle differ
diff --git a/tests/_data/interactions.pickle b/tests/_data/interactions.pickle
new file mode 100644
index 0000000..7d88e15
Binary files /dev/null and b/tests/_data/interactions.pickle differ
diff --git a/tests/_data/receivers.pickle b/tests/_data/receivers.pickle
new file mode 100644
index 0000000..419c432
Binary files /dev/null and b/tests/_data/receivers.pickle differ
diff --git a/tests/_data/transmitters.pickle b/tests/_data/transmitters.pickle
new file mode 100644
index 0000000..f8818ea
Binary files /dev/null and b/tests/_data/transmitters.pickle differ
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..ee78a99
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,172 @@
+from io import StringIO
+from copy import deepcopy
+from shutil import copy
+from pathlib import Path
+from collections import defaultdict
+import json
+import pickle
+
+import pytest
+
+import pandas as pd
+
+from omnipath.constants import InteractionDataset
+from omnipath._core.cache._cache import MemoryCache
+from omnipath._core.query._query import QueryType
+from omnipath._core.utils._options import Options
+from omnipath.constants._pkg_constants import Key
+from omnipath._core.downloader._downloader import Downloader
+
+
+@pytest.fixture(scope="function")
+def options() -> "Options":
+    opt = Options.from_config()
+    opt.cache = None
+    opt.progress_bar = False
+    return opt
+
+
+@pytest.fixture(scope="function")
+def config_backup(tmpdir):
+    copy(Options.config_path, tmpdir / "config.ini")
+    yield
+    copy(tmpdir / "config.ini", Options.config_path)
+
+
+@pytest.fixture(scope="function")
+def cache_backup():
+    import omnipath as op
+
+    cache = deepcopy(op.options.cache)
+    pb = op.options.progress_bar
+    op.options.cache = MemoryCache()
+    op.options.progress_bar = False
+    yield
+    op.options.cache = cache
+    op.options.progress_bar = pb
+
+
+@pytest.fixture(scope="function")
+def downloader(options) -> "Downloader":
+    return Downloader(options)
+
+
+@pytest.fixture(scope="session")
+def csv_data() -> bytes:
+    str_handle = StringIO()
+    pd.DataFrame({"foo": range(5), "bar": "baz", "quux": 42}).to_csv(str_handle)
+
+    return bytes(str_handle.getvalue(), encoding="utf-8")
+
+
+@pytest.fixture(scope="session")
+def tsv_data() -> bytes:
+    str_handle = StringIO()
+    pd.DataFrame(
+        {
+            "foo": range(5),
+            "components_genesymbols": "foo",
+            "quux": 42,
+            "modification": "bar",
+        }
+    ).to_csv(str_handle, sep="\t")
+
+    return bytes(str_handle.getvalue(), encoding="utf-8")
+
+
+@pytest.fixture(scope="session")
+def intercell_data() -> bytes:
+    data = {}
+    data[Key.PARENT.s] = [42, 1337, 24, 42]
+    data[Key.CATEGORY.s] = ["foo", "bar", "bar", "foo"]
+
+    return bytes(json.dumps(data), encoding="utf-8")
+
+
+@pytest.fixture(scope="session")
+def resources() -> bytes:
+    data = defaultdict(dict)
+    data["foo"][Key.QUERIES.s] = {
+        QueryType.INTERCELL.endpoint: {Key.GENERIC_CATEGORIES.s: ["42"]}
+    }
+    data["bar"][Key.QUERIES.s] = {
+        QueryType.INTERCELL.endpoint: {Key.GENERIC_CATEGORIES.s: ["42", "13"]}
+    }
+    data["baz"][Key.QUERIES.s] = {
+        QueryType.INTERCELL.endpoint: {Key.GENERIC_CATEGORIES.s: ["24"]}
+    }
+    data["quux"][Key.QUERIES.s] = {
+        QueryType.ENZSUB.endpoint: {Key.GENERIC_CATEGORIES.s: ["24"]}
+    }
+
+    return bytes(json.dumps(data), encoding="utf-8")
+
+
+@pytest.fixture(scope="session")
+def interaction_resources() -> bytes:
+    data = defaultdict(dict)
+    for i, d in enumerate(InteractionDataset):
+        data[f"d_{i}"][Key.QUERIES.s] = {
+            QueryType.INTERACTIONS.endpoint: {Key.DATASETS.s: [d.value]}
+        }
+
+    return bytes(json.dumps(data), encoding="utf-8")
+
+
+@pytest.fixture(scope="session")
+def complexes() -> pd.DataFrame:
+    return pd.DataFrame(
+        {
+            "components_genesymbols": [
+                "foo",
+                "bar_baz_quux",
+                "baz_bar",
+                "bar_quux_foo",
+            ],
+            "dummy": 42,
+        }
+    )
+
+
+@pytest.fixture(scope="session")
+def interactions_data() -> bytes:
+    str_handle = StringIO()
+    with open(Path("tests") / "_data" / "interactions.pickle", "rb") as fin:
+        data: pd.DataFrame = pickle.load(fin)
+
+    data.to_csv(str_handle, sep="\t", index=False)
+
+    return bytes(str_handle.getvalue(), encoding="utf-8")
+
+
+@pytest.fixture(scope="session")
+def transmitters_data() -> bytes:
+    str_handle = StringIO()
+    with open(Path("tests") / "_data" / "transmitters.pickle", "rb") as fin:
+        data: pd.DataFrame = pickle.load(fin)
+
+    data.to_csv(str_handle, sep="\t", index=False)
+
+    return bytes(str_handle.getvalue(), encoding="utf-8")
+
+
+@pytest.fixture(scope="session")
+def receivers_data() -> bytes:
+    str_handle = StringIO()
+    with open(Path("tests") / "_data" / "receivers.pickle", "rb") as fin:
+        data: pd.DataFrame = pickle.load(fin)
+
+    data.to_csv(str_handle, sep="\t", index=False)
+
+    return bytes(str_handle.getvalue(), encoding="utf-8")
+
+
+@pytest.fixture(scope="session")
+def import_intercell_result() -> pd.DataFrame:
+    with open(Path("tests") / "_data" / "import_intercell_result.pickle", "rb") as fin:
+        return pickle.load(fin)
+
+
+@pytest.fixture(scope="session")
+def string_series() -> pd.Series:
+    return pd.Series(["foo:123", "bar:45;baz", None, "baz:67;bar:67", "foo;foo;foo"])
diff --git a/tests/test_cache.py b/tests/test_cache.py
new file mode 100644
index 0000000..c6ef55a
--- /dev/null
+++ b/tests/test_cache.py
@@ -0,0 +1,99 @@
+from copy import copy, deepcopy
+from pathlib import Path
+
+import pytest
+
+from omnipath import options, clear_cache
+from omnipath._core.cache._cache import FileCache, MemoryCache
+
+
+def test_clear_cache_high_lvl(cache_backup):
+    options.cache["foo"] = 42
+    assert len(options.cache) == 1
+    assert options.cache["foo"] == 42
+
+    clear_cache()
+
+    assert len(options.cache) == 0
+
+
+class TestMemoryCache:
+    def test_str_repr(self):
+        mc = MemoryCache()
+
+        assert str(mc) == f"<{mc.__class__.__name__}[size={len(mc)}]>"
+        assert repr(mc) == f"<{mc.__class__.__name__}[size={len(mc)}]>"
+
+    def test_path_is_None(self):
+        mc = MemoryCache()
+        assert mc.path is None
+
+    def test_copy_does_nothing(self):
+        mc = MemoryCache()
+
+        assert mc is mc.copy()
+        assert mc is copy(mc)
+
+    def test_deepcopy_work(self):
+        mc = MemoryCache()
+
+        assert mc is not deepcopy(mc)
+
+    def test_cache_works(self):
+        mc = MemoryCache()
+        sentinel = object()
+
+        mc["foo"] = sentinel
+
+        assert len(mc) == 1
+        assert mc["foo"] is sentinel
+
+        mc.clear()
+
+        assert len(mc) == 0
+
+
+class TestPickleCache:
+    def test_invalid_path(self):
+        with pytest.raises(TypeError):
+            FileCache(42)
+
+    def test_path(self, tmpdir):
+        fc = FileCache(Path(tmpdir))
+
+        assert isinstance(fc.path, Path)
+        assert str(fc.path) == str(tmpdir)
+
+    def test_str_repr(self, tmpdir):
+        fc = FileCache(Path(tmpdir))
+
+        assert (
+            str(fc)
+            == f"<{fc.__class__.__name__}[size={len(fc)}, path={str(tmpdir)!r}]>"
+        )
+        assert (
+            repr(fc)
+            == f"<{fc.__class__.__name__}[size={len(fc)}, path={str(tmpdir)!r}]>"
+        )
+
+    def test_cache_works(self, tmpdir):
+        fc = FileCache(Path(tmpdir))
+        sentinel = object()
+
+        assert "foo" not in fc
+        fc["foo"] = 42
+        fc["bar.pickle"] = sentinel
+
+        assert "foo" in fc
+        assert "foo.pickle" in fc
+        assert fc["bar.pickle"] is not sentinel
+
+    def test_clear_works(self, tmpdir):
+        fc = FileCache(Path(tmpdir))
+        fc["foo"] = 42
+        assert Path(fc.path).exists()
+
+        fc.clear()
+
+        assert len(fc) == 0
+        assert not Path(tmpdir).exists()
diff --git a/tests/test_downloader.py b/tests/test_downloader.py
new file mode 100644
index 0000000..66d6679
--- /dev/null
+++ b/tests/test_downloader.py
@@ -0,0 +1,195 @@
+from io import BytesIO, StringIO
+from urllib.parse import urljoin
+import logging
+
+import pytest
+
+import numpy as np
+import pandas as pd
+
+from omnipath import options as opt
+from omnipath._core.utils._options import Options
+from omnipath.constants._pkg_constants import UNKNOWN_SERVER_VERSION, Endpoint
+from omnipath._core.downloader._downloader import Downloader, _get_server_version
+
+
+class TestDownloader:
+    def test_options_wrong_type(self):
+        with pytest.raises(TypeError):
+            Downloader("foobar")
+
+    def test_str_repr(self, options: Options):
+        d = Downloader(options)
+
+        assert str(d) == f"<{d.__class__.__name__}[options={options}]>"
+        assert repr(d) == f"<{d.__class__.__name__}[options={options}]>"
+
+    def test_initialize_local_options(self, options: Options):
+        options.password = "foo"
+        options.timeout = 1337
+        d = Downloader(options)
+
+        assert d._options is not options
+        assert str(d._options) == str(options)
+        assert str(d._options) != str(opt)
+
+        options.password = "bar"
+        assert d._options.password == "foo"
+
+    def test_initialize_global_options(self):
+        d = Downloader()
+
+        assert d._options is not opt
+        assert str(d._options) == str(opt)
+
+    def test_resources_cached_values(self, downloader: Downloader, requests_mock):
+        data = {"foo": "bar", "42": 1337}
+        requests_mock.register_uri(
+            "GET", urljoin(downloader._options.url, Endpoint.RESOURCES.s), json=data
+        )
+
+        assert downloader.resources == data
+        assert requests_mock.called_once
+
+        assert downloader.resources == data
+        assert requests_mock.called_once
+
+    def test_resources_no_cached_values(self, downloader: Downloader, requests_mock):
+        data = {"foo": "bar", "42": 1337}
+        requests_mock.register_uri(
+            "GET", urljoin(downloader._options.url, Endpoint.RESOURCES.s), json=data
+        )
+
+        assert downloader.resources == data
+        assert requests_mock.called_once
+
+        downloader._options.cache.clear()
+
+        assert downloader.resources == data
+        assert len(requests_mock.request_history) == 2
+
+    def test_maybe_download_not_callable(self, downloader: Downloader):
+        with pytest.raises(TypeError):
+            downloader.maybe_download("foo", callback=None)
+
+    def test_maybe_download_wrong_callable(
+        self, downloader: Downloader, requests_mock, csv_data: bytes
+    ):
+        url = urljoin(downloader._options.url, "foobar")
+        requests_mock.register_uri("GET", url, content=csv_data)
+
+        with pytest.raises(ValueError, match=r"Expected object or value"):
+            downloader.maybe_download(url, callback=pd.read_json)
+
+    def test_maybe_download_passes_params(
+        self, downloader: Downloader, requests_mock, csv_data: bytes
+    ):
+        csv_url = urljoin(downloader._options.url, "foobar/?format=csv")
+        csv_df = pd.read_csv(BytesIO(csv_data))
+        json_url = urljoin(downloader._options.url, "foobar/?format=json")
+        json_handle = StringIO()
+        csv_df.to_json(json_handle)
+
+        requests_mock.register_uri("GET", csv_url, content=csv_data)
+        requests_mock.register_uri(
+            "GET", json_url, content=bytes(json_handle.getvalue(), encoding="utf-8")
+        )
+
+        res1 = downloader.maybe_download(csv_url, callback=pd.read_csv)
+        res2 = downloader.maybe_download(csv_url, callback=pd.read_csv)
+
+        assert res1 is res2
+        assert requests_mock.called_once
+        np.testing.assert_array_equal(res1.index, csv_df.index)
+        np.testing.assert_array_equal(res1.columns, csv_df.columns)
+        np.testing.assert_array_equal(res1.values, csv_df.values)
+
+        res1 = downloader.maybe_download(json_url, callback=pd.read_json)
+        res2 = downloader.maybe_download(json_url, callback=pd.read_json)
+
+        assert res1 is res2
+        assert len(requests_mock.request_history) == 2
+        np.testing.assert_array_equal(res1.index, csv_df.index)
+        np.testing.assert_array_equal(res1.columns, csv_df.columns)
+        np.testing.assert_array_equal(res1.values, csv_df.values)
+
+    def test_maybe_download_no_cache(
+        self, downloader: Downloader, requests_mock, csv_data: bytes
+    ):
+        url = urljoin(downloader._options.url, "foobar")
+        requests_mock.register_uri("GET", url, content=csv_data)
+
+        res1 = downloader.maybe_download(url, callback=pd.read_csv)
+        downloader._options.cache.clear()
+        res2 = downloader.maybe_download(url, callback=pd.read_csv)
+
+        assert res1 is not res2
+        assert len(requests_mock.request_history) == 2
+        np.testing.assert_array_equal(res1.index, res2.index)
+        np.testing.assert_array_equal(res1.columns, res2.columns)
+        np.testing.assert_array_equal(res1.values, res2.values)
+
+    def test_maybe_download_is_not_final(
+        self, downloader: Downloader, requests_mock, csv_data: bytes
+    ):
+        endpoint = "barbaz"
+        url = urljoin(downloader._options.url, endpoint)
+        requests_mock.register_uri("GET", url, content=csv_data)
+        csv_df = pd.read_csv(BytesIO(csv_data))
+
+        res = downloader.maybe_download(endpoint, callback=pd.read_csv, is_final=False)
+
+        assert requests_mock.called_once
+        np.testing.assert_array_equal(res.index, csv_df.index)
+        np.testing.assert_array_equal(res.columns, csv_df.columns)
+        np.testing.assert_array_equal(res.values, csv_df.values)
+
+    def test_get_server_version_not_decodable(
+        self, options: Options, requests_mock, caplog
+    ):
+        url = urljoin(options.url, Endpoint.ABOUT.s)
+        options.autoload = True
+        requests_mock.register_uri(
+            "GET", f"{url}?format=text", content=bytes("foobarbaz", encoding="utf-8")
+        )
+
+        with caplog.at_level(logging.DEBUG):
+            version = _get_server_version(options)
+
+        assert requests_mock.called_once
+        assert (
+            "Unable to get server version. Reason: `list index out of range`"
+            in caplog.text
+        )
+        assert version == UNKNOWN_SERVER_VERSION
+
+    def test_get_server_version_no_autoload(
+        self, options: Options, requests_mock, caplog
+    ):
+        url = urljoin(options.url, Endpoint.ABOUT.s)
+        options.autoload = False
+        requests_mock.register_uri("GET", f"{url}?format=text", text="foobarbaz")
+
+        with caplog.at_level(logging.DEBUG):
+            version = _get_server_version(options)
+
+        assert not requests_mock.called_once
+        assert (
+            "Unable to get server version. Reason: `Autoload is disabled.`"
+            in caplog.text
+        )
+        assert version == UNKNOWN_SERVER_VERSION
+
+    def test_get_server_version(self, options: Options, requests_mock):
+        url = urljoin(options.url, Endpoint.ABOUT.s)
+        options.autoload = True
+        requests_mock.register_uri(
+            "GET",
+            f"{url}?format=text",
+            content=bytes("foo bar baz\nversion: 42.1337.00", encoding="utf-8"),
+        )
+
+        version = _get_server_version(options)
+
+        assert requests_mock.called_once
+        assert version == "42.1337.00"
diff --git a/tests/test_interactions.py b/tests/test_interactions.py
new file mode 100644
index 0000000..ca7ece8
--- /dev/null
+++ b/tests/test_interactions.py
@@ -0,0 +1,235 @@
+from urllib.parse import urljoin
+import json
+
+import pytest
+
+import numpy as np
+import pandas as pd
+
+from omnipath import options
+from omnipath.constants import Organism, InteractionDataset
+from omnipath._core.requests import Intercell
+from omnipath.constants._pkg_constants import Key, Endpoint
+from omnipath._core.requests.interactions._utils import (
+    get_signed_ptms,
+    import_intercell_network,
+)
+from omnipath._core.requests.interactions._interactions import (
+    TFmiRNA,
+    Dorothea,
+    OmniPath,
+    TFtarget,
+    KinaseExtra,
+    LigRecExtra,
+    PathwayExtra,
+    AllInteractions,
+    Transcriptional,
+    miRNA,
+    lncRNAmRNA,
+)
+
+
+class TestInteractions:
+    def test_all_excluded_excluded(self):
+        with pytest.raises(
+            ValueError, match=r"After excluding `\d+` datasets, none were left."
+        ):
+            AllInteractions.get(exclude=list(InteractionDataset))
+
+    def test_invalid_excluded_datasets(self):
+        with pytest.raises(
+            ValueError, match=r"Invalid value `foo` for `InteractionDataset`."
+        ):
+            AllInteractions.get(exclude="foo")
+
+    @pytest.mark.parametrize(
+        "interaction",
+        [
+            PathwayExtra,
+            KinaseExtra,
+            LigRecExtra,
+            miRNA,
+            TFmiRNA,
+            lncRNAmRNA,
+            Dorothea,
+            TFtarget,
+            OmniPath,
+        ],
+    )
+    def test_resources(
+        self, cache_backup, interaction, interaction_resources: bytes, requests_mock
+    ):
+        url = urljoin(options.url, Endpoint.RESOURCES.s)
+        data = json.loads(interaction_resources)
+        requests_mock.register_uri(
+            "GET", f"{url}?format=json", content=interaction_resources
+        )
+
+        resources = interaction.resources()
+        for resource in resources:
+            assert {
+                InteractionDataset(d)
+                for d in data[resource][Key.QUERIES.s][
+                    interaction._query_type.endpoint
+                ][Key.DATASETS.s]
+            } & interaction()._datasets
+        assert requests_mock.called_once
+
+    def test_invalid_organism(self):
+        with pytest.raises(
+            ValueError, match=r"Invalid value `foo` for `Organism`. Valid options are:"
+        ):
+            AllInteractions.get(**{Key.ORGANISM.s: "foo"})
+
+    @pytest.mark.parametrize("organism", list(Organism))
+    def test_valid_organism(
+        self, cache_backup, organism, requests_mock, interaction_resources
+    ):
+        url = urljoin(options.url, AllInteractions._query_type.endpoint)
+        requests_mock.register_uri(
+            "GET",
+            f"{url}?fields=curation_effort%2Creferences%2Csources%2Ctype&"
+            f"format=tsv&license=academic&organism={organism.code}",
+            content=interaction_resources,
+        )
+
+        AllInteractions.get(organism=organism, format="tsv")
+        AllInteractions.get(organism=organism.value, format="tsv")
+        assert requests_mock.called_once
+
+    def test_dorothea_params(self):
+        params = Dorothea.params()
+
+        assert "dorothea_levels" in params
+        assert "dorothea_methods" in params
+        assert "tfregulons_levels" not in params
+        assert "tfregulons_methods" not in params
+        assert Key.DATASETS.s not in params
+
+    def test_tftarget_params(self):
+        params = TFtarget.params()
+
+        assert "dorothea_levels" not in params
+        assert "dorothea_methods" not in params
+        assert "tfregulons_levels" in params
+        assert "tfregulons_methods" in params
+        assert Key.DATASETS.s not in params
+
+    @pytest.mark.parametrize(
+        "interaction", [OmniPath, Transcriptional, AllInteractions]
+    )
+    def test_transcriptional_params(self, interaction):
+        params = interaction.params()
+
+        assert "dorothea_levels" in params
+        assert "dorothea_methods" in params
+        assert "tfregulons_levels" in params
+        assert "tfregulons_methods" in params
+        assert Key.DATASETS.s not in params
+
+    @pytest.mark.parametrize(
+        "interaction",
+        [PathwayExtra, KinaseExtra, LigRecExtra, miRNA, TFmiRNA, lncRNAmRNA],
+    )
+    def test_rest_params(self, interaction):
+        params = interaction.params()
+
+        assert "dorothea_levels" not in params
+        assert "dorothea_methods" not in params
+        assert "tfregulons_levels" not in params
+        assert "tfregulons_methods" not in params
+        assert Key.DATASETS.s not in params
+
+
+class TestUtils:
+    def test_get_signed_ptms_wrong_ptms_type(self):
+        with pytest.raises(TypeError, match=r"Expected `ptms`"):
+            get_signed_ptms(42, pd.DataFrame())
+
+    def test_get_signed_ptms_wrong_interactions_type(self):
+        with pytest.raises(TypeError, match=r"Expected `interactions`"):
+            get_signed_ptms(pd.DataFrame(), 42)
+
+    def test_get_signed_ptms(self):
+        ptms = pd.DataFrame(
+            {"enzyme": ["alpha", "beta", "gamma"], "substrate": [0, 1, 0], "foo": 42}
+        )
+        interactions = pd.DataFrame(
+            {
+                "source": ["gamma", "beta", "delta"],
+                "target": [0, 0, 1],
+                "is_stimulation": True,
+                "is_inhibition": False,
+                "bar": 1337,
+            }
+        )
+        expected = pd.merge(
+            ptms,
+            interactions[["source", "target", "is_stimulation", "is_inhibition"]],
+            left_on=["enzyme", "substrate"],
+            right_on=["source", "target"],
+            how="left",
+        )
+
+        res = get_signed_ptms(ptms, interactions)
+
+        np.testing.assert_array_equal(res.index, expected.index)
+        np.testing.assert_array_equal(res.columns, expected.columns)
+
+        np.testing.assert_array_equal(pd.isnull(res), pd.isnull(expected))
+        np.testing.assert_array_equal(
+            res.values[~pd.isnull(res)], expected.values[~pd.isnull(expected)]
+        )
+
+    def test_import_intercell_network(
+        self,
+        cache_backup,
+        requests_mock,
+        interactions_data: bytes,
+        transmitters_data: bytes,
+        receivers_data: bytes,
+        import_intercell_result: pd.DataFrame,
+    ):
+        interactions_url = urljoin(options.url, AllInteractions._query_type.endpoint)
+        intercell_url = urljoin(options.url, Intercell._query_type.endpoint)
+
+        # interactions
+        requests_mock.register_uri(
+            "GET",
+            f"{interactions_url}?datasets=dorothea&dorothea_levels=A&fields=curation_effort%2C"
+            f"references%2Csources&format=tsv&license=academic",
+            content=interactions_data,
+        )
+        # transmitter
+        requests_mock.register_uri(
+            "GET",
+            f"{intercell_url}?categories=ligand&causality=trans&format=tsv&license=academic&scope=generic",
+            content=transmitters_data,
+        )
+        # receiver
+        requests_mock.register_uri(
+            "GET",
+            f"{intercell_url}?categories=receptor&causality=rec&format=tsv&license=academic&scope=generic",
+            content=receivers_data,
+        )
+
+        res = import_intercell_network(
+            transmitter_params={"categories": "ligand"},
+            interactions_params={"datasets": "dorothea", "dorothea_levels": "A"},
+            receiver_params={"categories": "receptor"},
+        )
+
+        assert isinstance(res, pd.DataFrame)
+        assert res.shape == (31, 46)
+
+        np.testing.assert_array_equal(res.index, import_intercell_result.index)
+        np.testing.assert_array_equal(res.columns, import_intercell_result.columns)
+        np.testing.assert_array_equal(res.dtypes, import_intercell_result.dtypes)
+        np.testing.assert_array_equal(
+            pd.isnull(res), pd.isnull(import_intercell_result)
+        )
+        np.testing.assert_array_equal(
+            res.values[~pd.isnull(res)],
+            import_intercell_result.values[~pd.isnull(import_intercell_result)],
+        )
+        assert len(requests_mock.request_history) == 3
diff --git a/tests/test_omnipath.py b/tests/test_omnipath.py
deleted file mode 100644
index 638acbe..0000000
--- a/tests/test_omnipath.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import pytest
-
-from omnipath import options
-from omnipath.constants import License
-
-
-class TestOptions:
-    def test_options_invalid_url_type(self):
-        with pytest.raises(TypeError):
-            options.url = 42
-
-    def test_options_invalid_url(self):
-        with pytest.raises(ValueError):
-            options.url = "foo"
-
-    def test_options_url_localhost(self):
-        options.url = "https://localhost"
-
-        assert options.url == "https://localhost"
-
-    def test_options_invalid_license(self):
-        with pytest.raises(ValueError):
-            options.license = "foo"
-
-    def test_options_valid_license(self):
-        options.license = "commercial"
-
-        assert isinstance(options.license, License)
-        assert options.license == License.COMMERCIAL
diff --git a/tests/test_options.py b/tests/test_options.py
new file mode 100644
index 0000000..52988b4
--- /dev/null
+++ b/tests/test_options.py
@@ -0,0 +1,145 @@
+from os import remove
+from typing import Optional
+from pathlib import Path
+from configparser import NoSectionError
+
+import pytest
+
+from omnipath.constants import License
+from omnipath._core.utils._options import Options
+from omnipath.constants._pkg_constants import DEFAULT_OPTIONS
+
+
+class TestOptions:
+    def test_invalid_url_type(self, options: Options):
+        with pytest.raises(TypeError):
+            options.url = 42
+
+    def test_invalid_url(self, options: Options):
+        with pytest.raises(ValueError):
+            options.url = "foo"
+
+    def test_invalid_license(self, options: Options):
+        with pytest.raises(ValueError):
+            options.license = "foo"
+
+    def test_invalid_cache_type(self, options: Options):
+        with pytest.raises(TypeError):
+            options.cache = 42
+
+    def test_invalid_password_type(self, options: Options):
+        with pytest.raises(TypeError):
+            options.password = 42
+
+    def test_invalid_num_retries(self, options: Options):
+        with pytest.raises(ValueError):
+            options.num_retries = -1
+
+    def test_invalid_timeout(self, options: Options):
+        with pytest.raises(ValueError):
+            options.timeout = 0
+
+    def test_invalid_chunk_size(self, options: Options):
with pytest.raises(ValueError): + options.chunk_size = 0 + + def test_from_options_invalid_type(self): + with pytest.raises(TypeError): + Options.from_options("foo") + + def test_url_localhost(self, options: Options): + options.url = "https://localhost" + + assert options.url == "https://localhost" + + @pytest.mark.parametrize("license", list(License)) + def test_valid_license(self, options: Options, license: License): + options.license = license.value + + assert isinstance(options.license, License) + assert options.license == license + + @pytest.mark.parametrize("pwd", ["foo", None]) + def test_password(self, options: Options, pwd: Optional[str]): + options.password = pwd + + assert options.password == pwd + + def test_from_options(self, options: Options): + new_opt = Options.from_options(options) + + for k, v in options.__dict__.items(): + assert getattr(new_opt, k) == v + + def test_from_options_new_values(self, options: Options): + new_opt = Options.from_options( + options, autoload=not options.autoload, num_retries=0 + ) + + for k, v in options.__dict__.items(): + if k not in ("autoload", "num_retries"): + assert getattr(new_opt, k) == v + + assert new_opt.autoload != options.autoload + assert new_opt.num_retries == 0 + + def test_from_config_no_file(self, config_backup): + if Path(Options.config_path).exists(): + remove(Options.config_path) + + new_opt = Options.from_config() + + for k, v in DEFAULT_OPTIONS.__dict__.items(): + if hasattr(new_opt, k) and not k.startswith("_"): + assert getattr(new_opt, k) == v + + def test_from_config_section_is_not_url(self): + with pytest.raises(NoSectionError, match=r"No section: 'http://foo.bar'"): + Options.from_config("http://foo.bar") + + def test_write_config(self, options: Options, config_backup): + options.timeout = 1337 + options.license = License.COMMERCIAL + options.password = "foobarbaz" + options.write() + + new_opt = Options.from_config() + for k, v in options.__dict__.items(): + if k == "cache": + assert type(new_opt.cache) == type(options.cache) # noqa: E721 + elif k == "password": + # don't store the password in the file + assert getattr(new_opt, k) is None + elif k not in ("timeout", "license"): + assert getattr(new_opt, k) == v + + assert new_opt.timeout == 1337 + assert new_opt.license == License.COMMERCIAL + + def test_write_new_section(self, options: Options, config_backup): + options.timeout = 42 + options.write("https://foo.bar") + + new_opt = Options.from_config("https://foo.bar") + assert options is not new_opt + for k, v in options.__dict__.items(): + if k == "url": + assert v == options.url + assert new_opt.url == "https://foo.bar" + elif k == "cache": + assert type(new_opt.cache) == type(options.cache) # noqa: E721 + else: + assert getattr(new_opt, k) == v + + def test_write_new_section_not_url(self, options: Options, config_backup): + with pytest.raises(ValueError, match=r"Invalid URL: `foobar`."): + options.write("foobar") + + def test_contextmanager(self, options: Options): + with options as new_opt: + assert options is not new_opt + for k, v in options.__dict__.items(): + if k == "cache": + assert type(new_opt.cache) == type(options.cache) # noqa: E721 + else: + assert getattr(new_opt, k) == v diff --git a/tests/test_query.py b/tests/test_query.py new file mode 100644 index 0000000..9d0c204 --- /dev/null +++ b/tests/test_query.py @@ -0,0 +1,170 @@ +from typing import _GenericAlias +from collections import defaultdict + +import pytest + +from omnipath._core.query._query import ( + Query, + QueryType, + EnzsubQuery, +
ComplexesQuery, + IntercellQuery, + AnnotationsQuery, + InteractionsQuery, + _get_synonyms, +) +from omnipath._core.query._query_validator import ( + EnzsubValidator, + ComplexesValidator, + IntercellValidator, + AnnotationsValidator, + InteractionsValidator, + _to_string_set, +) + + +class TestUtils: + def test_get_synonyms_wrong_type(self): + with pytest.raises(TypeError): + _get_synonyms(42) + + def test_get_synonyms_from_s2p(self): + res = _get_synonyms("cat") + + assert len(res) == 2 + assert res == ("cat", "cats") + + def test_get_synonyms_from_p2s(self): + res = _get_synonyms("dogs") + + assert len(res) == 2 + assert res == ("dog", "dogs") + + def test_to_string_set_string(self): + assert {"foo"} == _to_string_set("foo") + + def test_to_string_set_int(self): + assert {"42"} == _to_string_set(42) + + def test_to_string_set_sequence(self): + assert {"foo", "42"} == _to_string_set(["foo", 42]) + + +class TestValidator: + @pytest.mark.parametrize( + "validator", + [ + EnzsubValidator, + InteractionsValidator, + ComplexesValidator, + AnnotationsValidator, + IntercellValidator, + ], + ) + def test_validator_no_server_access(self, validator): + for value in list(validator): + v = validator(value) + + assert v.valid is None + assert v.doc is None + + assert v(None) is None + assert v("foo") == {"foo"} + assert v(42) == {"42"} + assert v(True) == {"1"} + assert v(False) == {"0"} + assert v(["foo", "foo"]) == {"foo"} + assert v(["foo", 42]) == {"foo", "42"} + assert v({"foo", "bar", "baz"}) == {"foo", "bar", "baz"} + + assert issubclass(type(v.annotation), (_GenericAlias, type)) + + +class TestQuery: + @pytest.mark.parametrize( + "query,validator", + zip( + [ + EnzsubQuery, + InteractionsQuery, + ComplexesQuery, + AnnotationsQuery, + IntercellQuery, + ], + [ + EnzsubValidator, + InteractionsValidator, + ComplexesValidator, + AnnotationsValidator, + IntercellValidator, + ], + ), + ) + def test_query_correct_validator(self, query, validator): + assert query.__validator__ == validator + + def test_query_endpoint(self): + for q in list(QueryType): + q = QueryType(q) + + assert issubclass(q.value, Query) + assert q.endpoint == q.name.lower() + + @pytest.mark.parametrize( + "query,validator", + zip( + [ + EnzsubQuery, + InteractionsQuery, + ComplexesQuery, + AnnotationsQuery, + IntercellQuery, + ], + [ + EnzsubValidator, + InteractionsValidator, + ComplexesValidator, + AnnotationsValidator, + IntercellValidator, + ], + ), + ) + def test_query_delegation(self, query, validator, mocker): + call_spy = mocker.spy(validator, "__call__") + + qdb = query("databases") + _ = qdb("foo") + + call_spy.assert_called_once_with( + getattr(qdb.__validator__, qdb._query_name), "foo" + ) + assert call_spy.spy_return == {"foo"} + assert qdb.doc is None + + for attr in ("valid", "annotation", "doc"): + m = mocker.patch.object( + validator, attr, new_callable=mocker.PropertyMock, return_value="foo" + ) + assert getattr(qdb, attr) == "foo" + + m.assert_called_once() + + @pytest.mark.parametrize( + "query", + [ + EnzsubQuery, + InteractionsQuery, + ComplexesQuery, + AnnotationsQuery, + IntercellQuery, + ], + ) + def test_query_synonym(self, query): + mapper = defaultdict(list) + for v in list(query): + name = "_".join(v.name.split("_")[:-1]) + mapper[name].append(v.value) + + for vs in mapper.values(): + assert len(vs) == 2 + # both synonyms must resolve to the same query parameter + assert len({query(v).param for v in vs}) == 1 diff --git a/tests/test_requests.py b/tests/test_requests.py new file mode 100644 index 0000000..7cedc52 --- /dev/null +++ b/tests/test_requests.py @@ -0,0
+1,369 @@ +from io import StringIO +from typing import Iterable, _GenericAlias +from urllib.parse import urljoin +import json +import logging + +import pytest + +from pandas.api.types import is_object_dtype, is_categorical_dtype +import numpy as np +import pandas as pd + +from omnipath import options +from omnipath.requests import Enzsub, Complexes, Intercell, Annotations +from omnipath._core.query._query import EnzsubQuery +from omnipath._core.requests._utils import _split_unique_join, _strip_resource_label +from omnipath.constants._pkg_constants import Key, Endpoint + + +class TestEnzsub: + def test_str_repr(self): + assert str(Enzsub()) == f"<{Enzsub().__class__.__name__}>" + assert repr(Enzsub()) == f"<{Enzsub().__class__.__name__}>" + + def test_params_no_org_genesymbol(self): + params = Enzsub.params() + + assert Key.ORGANISM.value not in params + assert Key.GENESYMBOLS.value not in params + + for k, valid in params.items(): + if isinstance(valid, Iterable): + np.testing.assert_array_equal( + list(set(valid)), list(set(EnzsubQuery(k).valid)) + ) + else: + assert valid == EnzsubQuery(k).valid + + def test_resources(self, cache_backup, requests_mock, resources: bytes): + url = urljoin(options.url, Endpoint.RESOURCES.s) + requests_mock.register_uri("GET", f"{url}?format=json", content=resources) + + res = Enzsub.resources() + + assert res == ("quux",) + assert requests_mock.called_once + + def test_invalid_params(self): + with pytest.raises(ValueError, match=r"Invalid value `foo` for `EnzsubQuery`."): + Enzsub.get(foo="bar") + + def test_invalid_license(self): + with pytest.raises(ValueError, match=r"Invalid value `bar` for `License`."): + Enzsub.get(license="bar") + + def test_invalid_format(self): + with pytest.raises(ValueError, match=r"Invalid value `bar` for `Format`."): + Enzsub.get(format="bar") + + def test_valid_params(self, cache_backup, requests_mock, tsv_data: bytes, caplog): + url = urljoin(options.url, Enzsub._query_type.endpoint) + df = pd.read_csv(StringIO(tsv_data.decode("utf-8")), sep="\t") + requests_mock.register_uri( + "GET", + f"{url}?fields=curation_effort%2Creferences%2Csources&format=tsv&license=academic", + content=tsv_data, + ) + + with caplog.at_level(logging.WARNING): + res = Enzsub.get(license="academic", format="text") + + assert f"Invalid `{Key.FORMAT.s}='text'`" in caplog.text + np.testing.assert_array_equal(res.index, df.index) + np.testing.assert_array_equal(res.columns, df.columns) + np.testing.assert_array_equal(res.values, df.values) + assert requests_mock.called_once + + def test_annotations(self): + assert set(Enzsub._annotations().keys()) == {e.param for e in EnzsubQuery} + assert all( + isinstance(a, (_GenericAlias, type)) for a in Enzsub._annotations().values() + ) + + def test_docs(self): + assert set(Enzsub._docs().keys()) == {e.param for e in EnzsubQuery} + assert all(d is None for d in Enzsub._docs().values()) + + def test_invalid_organism_does_not_matter( + self, cache_backup, requests_mock, tsv_data: bytes + ): + url = urljoin(options.url, Enzsub._query_type.endpoint) + requests_mock.register_uri( + "GET", + f"{url}?fields=curation_effort%2Creferences%2Csources&format=tsv&license=academic", + content=tsv_data, + ) + _ = Enzsub.get(organism="foobarbaz") + + assert requests_mock.called_once + + def test_genesymbols_dont_matter( + self, cache_backup, requests_mock, tsv_data: bytes + ): + url = urljoin(options.url, Enzsub._query_type.endpoint) + requests_mock.register_uri( + "GET", + 
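+            # the mocked URL below omits `genesymbols`: the enzsub endpoint is expected to drop this flag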
f"{url}?fields=curation_effort%2Creferences%2Csources&format=tsv&license=academic", + content=tsv_data, + ) + _ = Enzsub.get(genesymbol=True) + + assert requests_mock.called_once + + def test_field_injection(self, cache_backup, requests_mock, tsv_data: bytes): + url = urljoin(options.url, Enzsub._query_type.endpoint) + requests_mock.register_uri( + "GET", + f"{url}?fields=Alpha%2Cbeta%2Ccuration_effort%2Creferences%2Csources&format=tsv&license=academic", + content=tsv_data, + ) + _ = Enzsub.get(fields=("beta", "Alpha", "Alpha")) + + assert requests_mock.called_once + + def test_no_dtype_conversion(self, cache_backup, requests_mock, tsv_data: bytes): + url = urljoin(options.url, Enzsub._query_type.endpoint) + options.convert_dtypes = False + + requests_mock.register_uri( + "GET", + f"{url}?fields=curation_effort%2Creferences%2Csources&format=tsv&license=academic", + content=tsv_data, + ) + + res = Enzsub.get() + assert is_object_dtype(res["modification"]) + + options.convert_dtypes = True + + res = Enzsub.get() + assert is_categorical_dtype(res["modification"]) + assert requests_mock.called_once + + +class TestIntercell: + def test_resources_wrong_type(self): + with pytest.raises(TypeError): + Intercell.resources(42) + + def test_resources_no_generic_resources(self): + with pytest.raises( + ValueError, match=r"No generic categories have been selected." + ): + Intercell.resources([]) + + def test_resources_no_generic(self, cache_backup, requests_mock, resources: bytes): + url = urljoin(options.url, Endpoint.RESOURCES.s) + requests_mock.register_uri("GET", f"{url}?format=json", content=resources) + + res = Intercell.resources() + + assert res == ("bar", "baz", "foo") + assert requests_mock.called_once + + def test_resources_generic(self, cache_backup, requests_mock, resources: bytes): + url = urljoin(options.url, Endpoint.RESOURCES.s) + requests_mock.register_uri("GET", f"{url}?format=json", content=resources) + + res = Intercell.resources(generic_categories=["42"]) + assert res == ("bar", "foo") + + res = Intercell.resources(generic_categories="24") + assert res == ("baz",) + + res = Intercell.resources(generic_categories="foobarbaz") + assert res == () + assert requests_mock.called_once # caching + + def test_categories(self, cache_backup, requests_mock, intercell_data: bytes): + url = urljoin(options.url, Key.INTERCELL_SUMMARY.s) + data = json.loads(intercell_data) + requests_mock.register_uri("GET", f"{url}?format=json", content=intercell_data) + + res = Intercell.categories() + + assert res == tuple(sorted(set(map(str, data[Key.CATEGORY.s])))) + assert requests_mock.called_once + + def test_generic_categories( + self, cache_backup, requests_mock, intercell_data: bytes + ): + url = urljoin(options.url, Key.INTERCELL_SUMMARY.s) + data = json.loads(intercell_data) + requests_mock.register_uri("GET", f"{url}?format=json", content=intercell_data) + + res = Intercell.generic_categories() + + assert res == tuple(sorted(set(map(str, data[Key.PARENT.s])))) + assert requests_mock.called_once + + def test_password_from_options( + self, cache_backup, requests_mock, intercell_data: bytes + ): + old_pwd = options.password + options.password = "foobar" + + url = urljoin(options.url, Intercell._query_type.endpoint) + requests_mock.register_uri( + "GET", + f"{url}?format=tsv&license=academic&password=foobar", + content=intercell_data, + ) + + _ = Intercell.get() + options.password = old_pwd + + assert requests_mock.called_once + + def test_password_from_function_call( + self, cache_backup, requests_mock, 
intercell_data: bytes + ): + old_pwd = options.password + options.password = "foobar" + + url = urljoin(options.url, Intercell._query_type.endpoint) + requests_mock.register_uri( + "GET", + f"{url}?format=tsv&license=academic&password=bazquux", + content=intercell_data, + ) + + _ = Intercell.get(password="bazquux") + options.password = old_pwd + + assert requests_mock.called_once + + +class TestComplex: + def test_complex_genes_wrong_dtype(self): + with pytest.raises(TypeError): + Complexes.complex_genes("foo", complexes=42) + + def test_complexes_genes_empty_complexes(self, caplog): + df = pd.DataFrame() + with caplog.at_level(logging.WARNING): + res = Complexes.complex_genes("foo", complexes=df) + + assert res is df + assert "Complexes are empty" in caplog.text + + def test_complex_genes_no_column(self): + with pytest.raises(KeyError): + Complexes.complex_genes("foo", complexes=pd.DataFrame({"foo": range(10)})) + + def test_complex_genes_no_genes(self): + with pytest.raises(ValueError, match=r"No genes have been selected."): + Complexes.complex_genes([], complexes=None) + + def test_complex_genes_complexes_not_specified( + self, cache_backup, requests_mock, tsv_data: bytes + ): + url = urljoin(options.url, Complexes._query_type.endpoint) + df = pd.read_csv(StringIO(tsv_data.decode("utf-8")), sep="\t") + requests_mock.register_uri("GET", f"{url}?format=tsv", content=tsv_data) + + res = Complexes.complex_genes("fooo") + + np.testing.assert_array_equal(res.columns, df.columns) + assert res.empty + + def test_complexes_complexes_specified(self, complexes: pd.DataFrame): + res = Complexes.complex_genes("foo", complexes=complexes, total_match=False) + + assert isinstance(res, pd.DataFrame) + assert res.shape == (2, 2) + assert set(res.columns) == {"components_genesymbols", "dummy"} + assert all( + any(v in "foo" for v in vs.split("_")) + for vs in res["components_genesymbols"] + ) + + def test_complexes_total_match(self, complexes: pd.DataFrame): + res = Complexes.complex_genes( + ["bar", "baz"], complexes=complexes, total_match=True + ) + + assert res.shape == (1, 2) + assert all( + all(v in ("bar", "baz") for v in vs.split("_")) + for vs in res["components_genesymbols"] + ) + + def test_complexes_no_total_match(self, complexes: pd.DataFrame): + res = Complexes.complex_genes( + ["bar", "baz", "bar"], complexes=complexes, total_match=False + ) + + assert res.shape == (3, 2) + assert all( + any(v in ("bar", "baz") for v in vs.split("_")) + for vs in res["components_genesymbols"] + ) + + +class TestAnnotations: + def test_too_many_proteins_requested(self): + with pytest.raises(ValueError, match=r"Cannot download annotations for"): + Annotations.get([f"foo_{i}" for i in range(601)]) + + def test_params(self): + params = Annotations.params() + assert Key.ORGANISM.value not in params + + def test_genesymbols_matter(self, cache_backup, requests_mock, tsv_data: bytes): + url = urljoin(options.url, Annotations._query_type.endpoint) + requests_mock.register_uri( + "GET", f"{url}?proteins=bar&genesymbols=1&format=tsv", content=tsv_data + ) + df = pd.read_csv(StringIO(tsv_data.decode("utf-8")), sep="\t") + + res = Annotations.get(["bar"], genesymbols=True) + + np.testing.assert_array_equal(res.index, df.index) + np.testing.assert_array_equal(res.columns, df.columns) + np.testing.assert_array_equal(res.values, df.values) + + def test_invalid_organism_does_not_matter( + self, cache_backup, requests_mock, tsv_data: bytes + ): + url = urljoin(options.url, Annotations._query_type.endpoint) +
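+        # no `organism` in the mocked URL: annotations queries are expected to discard it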
requests_mock.register_uri( + "GET", f"{url}?proteins=foo&format=tsv", content=tsv_data + ) + df = pd.read_csv(StringIO(tsv_data.decode("utf-8")), sep="\t") + + res = Annotations.get(["foo", "foo"], organism="foobarbaz") + + np.testing.assert_array_equal(res.index, df.index) + np.testing.assert_array_equal(res.columns, df.columns) + np.testing.assert_array_equal(res.values, df.values) + + +class TestUtils: + def test_split_unique_join_no_func(self, string_series: pd.Series): + res = _split_unique_join(string_series) + + np.testing.assert_array_equal( + res, pd.Series(["foo:123", "bar:45;baz", None, "bar:67;baz:67", "foo"]) + ) + + def test_split_unique_join_func(self, string_series: pd.Series): + res = _split_unique_join(string_series, func=len) + + np.testing.assert_array_equal(res, pd.Series([1, 2, None, 2, 3], dtype=object)) + + def test_strip_resource_label_no_func(self, string_series: pd.Series): + res = _strip_resource_label(string_series, func=None) + + np.testing.assert_array_equal( + res, pd.Series(["123", "45;baz", None, "67", "foo"]) + ) + + def test_strip_resource_label_func(self): + res = _strip_resource_label( + pd.Series(["abc:123;bcd:123", "aaa:123", "a:1;b:2;c:3"]), + func=lambda row: len(set(row)), + ) + + np.testing.assert_array_equal(res, pd.Series([1, 1, 3])) diff --git a/tox.ini b/tox.ini index 05665b7..356e4f6 100644 --- a/tox.ini +++ b/tox.ini @@ -18,22 +18,89 @@ exclude = max_line_length = 120 filename = *.py +[gh-actions] +python = + 3.7: py37 + 3.8: py38 + 3.9: py39 + +[gh-actions:env] +PLATFORM = + ubuntu-latest: linux + macos-latest: macos + [pytest] python_files = test_*.py testpaths = tests/ xfail_strict = true +requests_mock_case_sensitive = true [tox] min_version=3.20.0 isolated_build = True -envlist = py{36,37,38,39}-{linux,macos} skip_missing_interpreters=true +envlist = + covclean + lint + py{37,38,39}-{linux,macos} + coverage + readme + docs [testenv] platform = linux: linux - macos: osx|darwin + macos: (macos|osx|darwin) deps = - . pytest -commands = pytest + pytest-mock + pytest-cov + requests-mock + numpy +setenv = OMNIPATH_AUTOLOAD = false +usedevelop = true +commands = pytest --cov --cov-append --cov-config={toxinidir}/.coveragerc --ignore docs/ {posargs:-vv} + +[testenv:covclean] +description = Clean coverage files. +deps = coverage +skip_install = True +commands = coverage erase + +[testenv:lint] +description = Perform linting. +basepython = python3.8 +deps = pre-commit>=2.7.1 +skip_install = true +commands = + pre-commit run --all-files --show-diff-on-failure {posargs:} + +[testenv:coverage] +description = Report the coverage difference. +deps = + coverage + diff_cover +skip_install = true +depends = py{37,38,39}-{linux,macos} +parallel_show_output = True +commands = + coverage report --omit="tox/*" + coverage xml --omit="tox/*" -o {toxinidir}/coverage.xml + diff-cover --compare-branch origin/master {toxinidir}/coverage.xml + +[testenv:docs] +description = Build the documentation. +basepython = python3.8 +extras = docs +whitelist_externals = sphinx-build +commands = + sphinx-build --color -b html {toxinidir}/docs/source {toxinidir}/docs/build/html + python -c 'import pathlib; print(f"Documentation is available under:", pathlib.Path(f"{toxinidir}") / "docs" / "build" / "html" / "index.html")' + +[testenv:readme] +description = Check if README renders on PyPI. +basepython = python3.8 +deps = twine >= 1.12.1 +skip_install = true +commands = pip wheel -q -w {envtmpdir}/build --no-deps . + twine check {envtmpdir}/build/*