Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Client): expose client extra headers in init function #1715

Merged
merged 4 commits into from Sep 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 13 additions & 2 deletions src/rubrix/client/api.py
Expand Up @@ -113,6 +113,7 @@ def __init__(
api_key: Optional[str] = None,
workspace: Optional[str] = None,
timeout: int = 60,
extra_headers: Optional[Dict[str, str]] = None,
):
"""Init the Python client.

Expand All @@ -127,22 +128,32 @@ def __init__(
workspace: The workspace to which records will be logged/loaded. If `None` (default) and the
env variable ``RUBRIX_WORKSPACE`` is not set, it will default to the private user workspace.
timeout: Wait `timeout` seconds for the connection to timeout. Default: 60.
extra_headers: Extra HTTP headers sent to the server. You can use this to customize
the headers of Rubrix client requests, like additional security restrictions. Default: `None`.

Examples:
>>> import rubrix as rb
>>> rb.init(api_url="http://localhost:9090", api_key="4AkeAPIk3Y")
>>> # Customizing request headers
>>> headers = {"X-Client-id":"id","X-Secret":"secret"}
>>> rb.init(api_url="http://localhost:9090", api_key="4AkeAPIk3Y", extra_headers=headers)

"""
api_url = api_url or os.getenv("RUBRIX_API_URL", "http://localhost:6900")
# Checking that the api_url does not end in '/'
api_url = re.sub(r"\/$", "", api_url)
api_key = api_key or os.getenv("RUBRIX_API_KEY", DEFAULT_API_KEY)
workspace = workspace or os.getenv("RUBRIX_WORKSPACE")
headers = extra_headers or {}

self._client: AuthenticatedClient = AuthenticatedClient(
base_url=api_url, token=api_key, timeout=timeout
base_url=api_url,
token=api_key,
timeout=timeout,
headers=headers.copy(),
)
self._user: User = users_api.whoami(client=self._client)

self._user: User = users_api.whoami(client=self._client)
if workspace is not None:
self.set_workspace(workspace)

Expand Down
18 changes: 16 additions & 2 deletions tests/client/test_init.py
@@ -1,8 +1,8 @@
from rubrix.client import api


def test_resource_leaking_with_several_inits(mocked_client):
dataset = "test_resource_leaking_with_several_inits"
def test_resource_leaking_with_several_init(mocked_client):
dataset = "test_resource_leaking_with_several_init"
api.delete(dataset)

# TODO: review performance in Windows. See https://github.com/recognai/rubrix/pull/1702
Expand All @@ -16,3 +16,17 @@ def test_resource_leaking_with_several_inits(mocked_client):
)

assert len(api.load(dataset)) == 10


def test_init_with_extra_headers(mocked_client):
expected_headers = {
"X-Custom-Header": "Mocking rules!",
"Other-header": "Header value",
}
api.init(extra_headers=expected_headers)
active_api = api.active_api()

for key, value in expected_headers.items():
assert (
active_api.client.headers[key] == value
), f"{key}:{value} not in client headers"