Skip to content

Commit

Permalink
feat(#1130): cleanup rb namespace by refactoring client API (#1160)
Browse files Browse the repository at this point in the history
  • Loading branch information
frascuchon committed Mar 30, 2022
1 parent 4dc8c55 commit a0fdd8e
Show file tree
Hide file tree
Showing 23 changed files with 1,142 additions and 1,049 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Expand Up @@ -135,3 +135,6 @@ package-lock.json

# App generated files
src/**/server/static/

# setuptools_scm generated file
src/rubrix/_version.py
1 change: 1 addition & 0 deletions pyproject.toml
Expand Up @@ -3,6 +3,7 @@ requires = ["setuptools", "wheel", "setuptools_scm[toml]"]
build-backend = "setuptools.build_meta"

[tool.setuptools_scm]
write_to = "src/rubrix/_version.py"

[tool.pytest.ini_options]
log_format = "%(asctime)s %(name)s %(levelname)s %(message)s"
Expand Down
301 changes: 75 additions & 226 deletions src/rubrix/__init__.py
Expand Up @@ -13,240 +13,89 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""This file reflects the user facing API.
If you want to add something here, remember to add it as normal import in the _TYPE_CHECKING section (for IDEs),
as well as in the `_import_structure` dictionary.
"""
This module contains the interface to access Rubrix's REST API.
"""

import os
import re
from logging import getLogger
from typing import Any, Dict, Iterable, List, Optional, Union

import pandas
import pkg_resources

from rubrix._constants import DEFAULT_API_KEY
from rubrix.client import RubrixClient
from rubrix.client.datasets import (
Dataset,
DatasetForText2Text,
DatasetForTextClassification,
DatasetForTokenClassification,
read_datasets,
read_pandas,
)
from rubrix.client.models import (
BulkResponse,
Record,
Text2TextRecord,
TextClassificationRecord,
TokenAttributions,
TokenClassificationRecord,
)
from rubrix.logging import configure_logging
from rubrix.monitoring.model_monitor import monitor

try:
__version__ = pkg_resources.get_distribution(__name__).version
except pkg_resources.DistributionNotFound:
# package is not installed
pass

try:
from rubrix.server.server import app
except ModuleNotFoundError as ex:
module_name = ex.name

def fallback_app(*args, **kwargs):
raise RuntimeError(
"\n"
f"Cannot start rubrix server. Some dependencies was not found:[{module_name}].\n"
"Please, install missing modules or reinstall rubrix with server extra deps:\n"
"pip install rubrix[server]"
)

app = fallback_app

configure_logging()

_client: Optional[
RubrixClient
] = None # Client will be stored here to pass it through functions


_LOGGER = getLogger(__name__)


def _client_instance() -> RubrixClient:
"""Checks module instance client and init if not initialized."""

global _client
# Calling a by-default-init if it was not called before
if _client is None:
init()
return _client


def init(
api_url: Optional[str] = None,
api_key: Optional[str] = None,
workspace: Optional[str] = None,
timeout: int = 60,
) -> None:
"""Init the python client.
Passing an api_url disables environment variable reading, which will provide
default values.
Args:
api_url: Address of the REST API. If `None` (default) and the env variable ``RUBRIX_API_URL`` is not set,
it will default to `http://localhost:6900`.
api_key: Authentification key for the REST API. If `None` (default) and the env variable ``RUBRIX_API_KEY``
is not set, it will default to `rubrix.apikey`.
workspace: The workspace to which records will be logged/loaded. If `None` (default) and the
env variable ``RUBRIX_WORKSPACE`` is not set, it will default to the private user workspace.
timeout: Wait `timeout` seconds for the connection to timeout. Default: 60.
Examples:
>>> import rubrix as rb
>>> rb.init(api_url="http://localhost:9090", api_key="4AkeAPIk3Y")
"""

global _client

final_api_url = api_url or os.getenv("RUBRIX_API_URL", "http://localhost:6900")

# Checking that the api_url does not end in '/'
final_api_url = re.sub(r"\/$", "", final_api_url)
import sys as _sys
from typing import TYPE_CHECKING as _TYPE_CHECKING

# If an api_url is passed, tokens obtained via environ vars are disabled
final_key = api_key or os.getenv("RUBRIX_API_KEY", DEFAULT_API_KEY)
from rubrix.logging import configure_logging as _configure_logging

workspace = workspace or os.getenv("RUBRIX_WORKSPACE")
from . import _version
from .utils import _LazyRubrixModule

_LOGGER.info(f"Rubrix has been initialized on {final_api_url}")
__version__ = _version.version

_client = RubrixClient(
api_url=final_api_url,
api_key=final_key,
workspace=workspace,
timeout=timeout,
if _TYPE_CHECKING:
from rubrix.client.api import (
copy,
delete,
get_workspace,
init,
load,
log,
set_workspace,
)


def get_workspace() -> str:
"""Returns the name of the active workspace for the current client session.
Returns:
The name of the active workspace as a string.
"""
return _client_instance().active_workspace


def set_workspace(ws: str) -> None:
"""Sets the active workspace for the current client session.
Args:
ws: The new workspace
"""
_client_instance().set_workspace(ws)


def log(
records: Union[Record, Iterable[Record], Dataset],
name: str,
tags: Optional[Dict[str, str]] = None,
metadata: Optional[Dict[str, Any]] = None,
chunk_size: int = 500,
verbose: bool = True,
) -> BulkResponse:
"""Log Records to Rubrix.
Args:
records: The record or an iterable of records.
name: The dataset name.
tags: A dictionary of tags related to the dataset.
metadata: A dictionary of extra info for the dataset.
chunk_size: The chunk size for a data bulk.
verbose: If True, shows a progress bar and prints out a quick summary at the end.
Returns:
Summary of the response from the REST API
Examples:
>>> import rubrix as rb
>>> record = rb.TextClassificationRecord(
... inputs={"text": "my first rubrix example"},
... prediction=[('spam', 0.8), ('ham', 0.2)]
... )
>>> response = rb.log(record, name="example-dataset")
"""
# noinspection PyTypeChecker,PydanticTypeChecker
return _client_instance().log(
records=records,
name=name,
tags=tags,
metadata=metadata,
chunk_size=chunk_size,
verbose=verbose,
from rubrix.client.datasets import (
DatasetForText2Text,
DatasetForTextClassification,
DatasetForTokenClassification,
read_datasets,
read_pandas,
)


def copy(dataset: str, name_of_copy: str, workspace: str = None):
"""Creates a copy of a dataset including its tags and metadata
Args:
dataset: Name of the source dataset
name_of_copy: Name of the copied dataset
workspace: If provided, dataset will be copied to that workspace
Examples:
>>> import rubrix as rb
>>> rb.copy("my_dataset", name_of_copy="new_dataset")
>>> dataframe = rb.load("new_dataset")
"""
_client_instance().copy(
source=dataset, target=name_of_copy, target_workspace=workspace
)


def load(
name: str,
query: Optional[str] = None,
ids: Optional[List[Union[str, int]]] = None,
limit: Optional[int] = None,
as_pandas: bool = True,
) -> Union[pandas.DataFrame, Dataset]:
"""Loads a dataset as a pandas DataFrame or a Dataset.
Args:
name: The dataset name.
query: An ElasticSearch query with the
`query string syntax <https://rubrix.readthedocs.io/en/stable/reference/webapp/search_records.html>`_
ids: If provided, load dataset records with given ids.
limit: The number of records to retrieve.
as_pandas: If True, return a pandas DataFrame. If False, return a Dataset.
Returns:
The dataset as a pandas Dataframe or a Dataset.
Examples:
>>> import rubrix as rb
>>> dataframe = rb.load(name="example-dataset")
"""
return _client_instance().load(
name=name, query=query, limit=limit, ids=ids, as_pandas=as_pandas
from rubrix.client.models import (
Text2TextRecord,
TextClassificationRecord,
TokenAttributions,
TokenClassificationRecord,
)
from rubrix.monitoring.model_monitor import monitor
from rubrix.server.server import app

_import_structure = {
"client.api": [
"copy",
"delete",
"get_workspace",
"init",
"load",
"log",
"set_workspace",
],
"client.models": [
"Text2TextRecord",
"TextClassificationRecord",
"TokenClassificationRecord",
"TokenAttributions",
],
"client.datasets": [
"DatasetForText2Text",
"DatasetForTextClassification",
"DatasetForTokenClassification",
"read_datasets",
"read_pandas",
],
"monitoring.model_monitor": ["monitor"],
"server.app": ["app"],
}

# can be removed in a future version
_deprecated_import_structure = {
"client.models": ["Record", "BulkResponse"],
"client.datasets": ["Dataset"],
"client.rubrix_client": ["RubrixClient"],
"_constants": ["DEFAULT_API_KEY"],
}

_sys.modules[__name__] = _LazyRubrixModule(
__name__,
globals()["__file__"],
_import_structure,
deprecated_import_structure=_deprecated_import_structure,
module_spec=__spec__,
extra_objects={"__version__": __version__},
)

def delete(name: str) -> None:
"""Delete a dataset.
Args:
name: The dataset name.
Examples:
>>> import rubrix as rb
>>> rb.delete(name="example-dataset")
"""
_client_instance().delete(name=name)
_configure_logging()
4 changes: 2 additions & 2 deletions src/rubrix/client/api.py
Expand Up @@ -84,8 +84,8 @@ def __init__(
):
"""Init the Python client.
We will automatically init a default client for you when calling other client methods.
The arguments provided here will overwrite your corresponding environment variables.
Passing an api_url disables environment variable reading, which will provide
default values.
Args:
api_url: Address of the REST API. If `None` (default) and the env variable ``RUBRIX_API_URL`` is not set,
Expand Down
16 changes: 12 additions & 4 deletions src/rubrix/client/rubrix_client.py
Expand Up @@ -12,11 +12,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""The Rubrix client, used by the rubrix.__init__ module"""
"""
The Rubrix client, used by the rubrix.__init__ module.
DEPRECATED, CAN BE REMOVED IN A FUTURE VERSION. USE THE rubrix.client.api MODULE INSTEAD!
"""

import logging
import socket
import warnings
from typing import Any, Dict, Iterable, List, Optional, Union

import pandas
Expand Down Expand Up @@ -80,7 +83,7 @@ class InputValueError(RubrixClientError):


class RubrixClient:
"""Class definition for Rubrix Client"""
"""DEPRECATED. Class definition for Rubrix Client"""

_LOGGER = logging.getLogger(__name__)
_WARNED_ABOUT_AS_PANDAS = False
Expand All @@ -97,14 +100,19 @@ def __init__(
workspace: Optional[str] = None,
timeout: int = 60,
):
"""Client setup function.
"""DEPRECATED. Client setup function.
Args:
api_url: Address from which the API is serving.
api_key: Authentication token.
workspace: Active workspace for this client session.
timeout: Seconds to wait before raising a connection timeout.
"""
warnings.warn(
f"The 'RubrixClient' class is deprecated and will be removed in a future version! "
f"Use the `rubrix.client.api` module instead. Make sure to adapt your code.",
category=FutureWarning,
)

self._client = AuthenticatedClient(
base_url=api_url, token=api_key, timeout=timeout
Expand Down

0 comments on commit a0fdd8e

Please sign in to comment.