Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Secure client with grpc auth #17

Merged
merged 1 commit into from
May 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ jobs:
run: |
source $VENV
pre-commit run --color=always --show-diff-on-failure
- name: Run tests with coverage
run: |
source $VENV
poetry run coverage run -m unittest discover -s tests -p "test_*.py"
poetry run coverage report -m

build:
needs: lint
strategy:
Expand All @@ -60,12 +66,16 @@ jobs:
with:
path: .venv
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('**/poetry.lock') }}
- name: Install dependencies if cache hit
- name: Install dependencies if cache miss
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction --no-root
- name: Install library
- name: Install dependencies if poetry.lock changed
run: poetry install --no-interaction
- name: Compile proto to generate API stubs
run: |
source $VENV
poetry run make generate
- name: Run tests
run: |
source $VENV
poetry run python -m unittest discover -s tests -p "test_*.py"
3 changes: 3 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ repos:
hooks:
- id: flake8
args: [--max-line-length=88]
additional_dependencies: [
pep8-naming==0.13.3
]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
Expand Down
275 changes: 177 additions & 98 deletions poetry.lock

Large diffs are not rendered by default.

7 changes: 7 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,17 @@ optional = true

[tool.poetry.group.dev.dependencies]
pre-commit = "^3.3.0"
coverage = {version = "^7.2.5", extras = ["toml"]}

[tool.poetry.scripts]
make = "scripts.proto:main"

[tool.coverage.run]
source = ["tigrisdb"]

[tool.coverage.report]
fail_under = 35

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
1 change: 0 additions & 1 deletion reviewpad.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ workflows:
- rule: $isDraft() == false && $rule("is-main-branch")
then:
- $titleLint()
- $fail("PR title should match 'conventional commit' format")

- name: license-validation
description: Validate that licenses are not modified
Expand Down
4 changes: 3 additions & 1 deletion scripts/proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def generate():
if pd_file.endswith(".proto"):
proto_sources.append(os.path.join(pd_path, pd_file))

for pf in ["api.proto", "search.proto"]:
for pf in ["api.proto", "search.proto", "auth.proto"]:
pf_path = os.path.join(TIGRIS_PROTO_DIR, pf)
proto_sources.append(pf_path)

Expand Down Expand Up @@ -68,6 +68,8 @@ def generate():
with open(fp, "w") as f:
f.write(fdata)

print(f"SUCCESS! Compiled proto stubs available in:\n{GENERATED_PROTO_DIR}")


def clean():
if os.path.exists(GENERATED_PROTO_DIR):
Expand Down
15 changes: 15 additions & 0 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
from typing import Optional

import grpc


class StubRpcError(grpc.RpcError):
def __init__(self, code: str, details: Optional[str]):
self._code = code
self._details = details

def code(self):
return self._code

def details(self):
return self._details
92 changes: 92 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import time
from unittest import TestCase
from unittest.mock import MagicMock, patch

import grpc

from api.generated.server.v1.auth_pb2 import GetAccessTokenResponse
from tests import StubRpcError
from tigrisdb.auth import AuthGateway
from tigrisdb.errors import TigrisServerError
from tigrisdb.types import ClientConfig
from tigrisdb.utils import dict_to_b64


@patch("api.generated.server.v1.auth_pb2_grpc.AuthStub")
@patch("grpc.channel_ready_future")
class AuthGatewayTest(TestCase):
def setUp(self) -> None:
self.done_future = MagicMock(grpc.Future)
self.client_config = ClientConfig(
server_url="localhost:5000", project_name="db1"
)

def test_get_access_token_with_valid_token_refresh_window(
self, channel_ready_future, grpc_auth
):
channel_ready_future.return_value = self.done_future
expiration_time = time.time()
mock_grpc_auth, expected_token = grpc_auth(), _encoded_token(expiration_time)
mock_grpc_auth.GetAccessToken.return_value = GetAccessTokenResponse(
access_token=expected_token
)

auth_gateway = AuthGateway(self.client_config)
actual_token = auth_gateway.get_access_token()
self.assertEqual(expected_token, actual_token)
next_refresh = auth_gateway.__getattribute__("_AuthGateway__next_refresh_time")
# refresh time is within 11 minutes of expiration time
self.assertLessEqual(expiration_time - next_refresh, 660)

def test_get_access_token_with_rpc_failure(self, channel_ready_future, grpc_auth):
channel_ready_future.return_value = self.done_future
mock_grpc_auth = grpc_auth()
mock_grpc_auth.GetAccessToken.side_effect = StubRpcError(
code="Unavailable", details=""
)

auth_gateway = AuthGateway(self.client_config)
with self.assertRaisesRegex(
TigrisServerError, "failed to get access token"
) as e:
auth_gateway.get_access_token()
self.assertIsNotNone(e)

def test_should_refresh_with_expired_token(self, channel_ready_future, grpc_auth):
channel_ready_future.return_value = self.done_future
auth_gateway = AuthGateway(self.client_config)

self.assertTrue(auth_gateway.should_refresh())
auth_gateway.__setattr__("_AuthGateway__cached_token", "xyz")
auth_gateway.__setattr__("_AuthGateway__next_refresh_time", time.time() + 5)
self.assertFalse(auth_gateway.should_refresh())
auth_gateway.__setattr__("_AuthGateway__next_refresh_time", time.time() - 5)
self.assertTrue(auth_gateway.should_refresh())

def test_should_refresh_without_cached_token(self, channel_ready_future, grpc_auth):
channel_ready_future.return_value = self.done_future
auth_gateway = AuthGateway(self.client_config)

self.assertTrue(auth_gateway.should_refresh())
auth_gateway.__setattr__("_AuthGateway__cached_token", "xyz")
self.assertTrue(auth_gateway.should_refresh())
auth_gateway.__setattr__("_AuthGateway__next_refresh_time", time.time() + 10)
self.assertFalse(auth_gateway.should_refresh())

def test_get_auth_headers(self, channel_ready_future, grpc_auth):
channel_ready_future.return_value = self.done_future
auth_gateway = AuthGateway(self.client_config)

auth_gateway.__setattr__("_AuthGateway__cached_token", "xyz")
auth_gateway.__setattr__("_AuthGateway__next_refresh_time", time.time() + 10)
self.assertFalse(auth_gateway.should_refresh())
expected_headers = [
("authorization", "Bearer xyz"),
("user-agent", "tigris-client-python.grpc"),
("destination-name", self.client_config.server_url),
]
self.assertCountEqual(expected_headers, auth_gateway.get_auth_headers(None))


def _encoded_token(expiration: float):
return f'token.{dict_to_b64({"exp": expiration})}'
54 changes: 0 additions & 54 deletions tests/test_documents.py

This file was deleted.

72 changes: 72 additions & 0 deletions tigrisdb/auth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from random import randint
from time import time

import grpc

from api.generated.server.v1 import auth_pb2_grpc as tigris_auth
from api.generated.server.v1.auth_pb2 import (
CLIENT_CREDENTIALS,
GetAccessTokenRequest,
GetAccessTokenResponse,
)
from tigrisdb.errors import TigrisException, TigrisServerError
from tigrisdb.types import ClientConfig
from tigrisdb.utils import b64_to_dict


class AuthGateway(grpc.AuthMetadataPlugin):
__cached_token: str = ""
__next_refresh_time: float = 0.0
__config: ClientConfig
__auth_stub: tigris_auth.AuthStub

def __init__(self, config: ClientConfig):
super(grpc.AuthMetadataPlugin, self).__init__()
self.__config = config
channel = grpc.secure_channel(config.server_url, grpc.ssl_channel_credentials())
try:
grpc.channel_ready_future(channel).result(timeout=10)
except grpc.FutureTimeoutError:
raise TigrisException(f"Auth connection timed out: {config.server_url}")
self.__auth_stub = tigris_auth.AuthStub(channel)

def get_access_token(self):
if self.should_refresh():
req = GetAccessTokenRequest(
grant_type=CLIENT_CREDENTIALS,
client_id=self.__config.client_id,
client_secret=self.__config.client_secret,
)
try:
resp: GetAccessTokenResponse = self.__auth_stub.GetAccessToken(req)
self.__cached_token = resp.access_token
token_meta = b64_to_dict(self.__cached_token.split(".")[1] + "==")
exp = float(token_meta["exp"])
self.__next_refresh_time = exp - 300 - float(randint(0, 300) + 60)
except grpc.RpcError as e:
raise TigrisServerError("failed to get access token", e)
return self.__cached_token

def should_refresh(self):
return (not self.__cached_token) or time() >= self.__next_refresh_time

def get_auth_headers(self, context: grpc.AuthMetadataContext):
headers = (
("authorization", f"Bearer {self.get_access_token()}"),
("user-agent", "tigris-client-python.grpc"),
("destination-name", self.__config.server_url),
)
return headers

def __call__(self, context, callback):
"""Implements authentication by passing metadata to a callback.

This method will be invoked asynchronously in a separate thread.

Args:
context: An AuthMetadataContext providing information on the RPC that
the plugin is being called to authenticate.
callback: An AuthMetadataPluginCallback to be invoked either
synchronously or asynchronously.
"""
callback(self.get_auth_headers(context), None)
37 changes: 24 additions & 13 deletions tigrisdb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,55 @@
import grpc

from api.generated.server.v1 import api_pb2_grpc as tigris_grpc
from tigrisdb.config import TigrisClientConfig
from tigrisdb.auth import AuthGateway
from tigrisdb.database import Database
from tigrisdb.errors import TigrisException
from tigrisdb.types import ClientConfig


class TigrisClient(object):
__LOCAL_SERVER = "localhost:8081"

__tigris_stub: tigris_grpc.TigrisStub
__config: TigrisClientConfig
__config: ClientConfig

def __init__(self, config: Optional[TigrisClientConfig]):
def __init__(self, config: Optional[ClientConfig]):
os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"

if not config:
config = TigrisClientConfig()
config = ClientConfig()
self.__config = config
if not config.server_url:
config.server_url = TigrisClient.__LOCAL_SERVER
if config.server_url.startswith("https://"):
config.server_url.replace("https://", "")

is_local_dev = filter(
lambda k: k in config.server_url,
["localhost", "127.0.0.1", "tigrisdb-local-server:", "[::1]"],
config.server_url = config.server_url.replace("https://", "")
if config.server_url.startswith("http://"):
config.server_url = config.server_url.replace("http://", "")
if ":" not in config.server_url:
config.server_url = f"{config.server_url}:443"

is_local_dev = any(
map(
lambda k: k in config.server_url,
["localhost", "127.0.0.1", "tigrisdb-local-server:", "[::1]"],
)
)

if is_local_dev:
channel = grpc.insecure_channel(config.server_url)
else:
raise NotImplementedError(
"Secure channels will be supported in upcoming versions"
auth_gtwy = AuthGateway(config)
channel_creds = grpc.ssl_channel_credentials()
call_creds = grpc.metadata_call_credentials(auth_gtwy, name="auth gateway")
channel = grpc.secure_channel(
config.server_url,
grpc.composite_channel_credentials(channel_creds, call_creds),
)

try:
grpc.channel_ready_future(channel).result(timeout=10)
except grpc.RpcError:
raise TigrisException(f"Connection timed out: {config.server_url}")
except grpc.FutureTimeoutError:
raise TigrisException(f"Connection timed out {config.server_url}")

self.__tigris_stub = tigris_grpc.TigrisStub(channel)

Expand Down
Loading