Skip to content

Commit

Permalink
Add GSSAPIAuthentication authentication class.
Browse files Browse the repository at this point in the history
  • Loading branch information
gkarg committed Feb 21, 2024
1 parent 00435d6 commit 5bc49e5
Show file tree
Hide file tree
Showing 5 changed files with 152 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ jobs:
sudo apt-get update
sudo apt-get install libkrb5-dev
pip install wheel
pip install .[tests] sqlalchemy${{ matrix.sqlalchemy }}
pip install .[tests,gssapi] sqlalchemy${{ matrix.sqlalchemy }}
- name: Run tests
run: |
pytest -s tests/
37 changes: 37 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,43 @@ the [`Kerberos` authentication type](https://trino.io/docs/current/security/kerb
)
```

### GSSAPI authentication

The `GSSAPIAuthentication` class can be used to connect to a Trino cluster configured with
the [`Kerberos` authentication type](https://trino.io/docs/current/security/kerberos.html):

It follows the interface for `KerberosAuthentication`, but is using
[requests-gssapi](https://github.com/pythongssapi/requests-gssapi), instead of [requests-kerberos](https://github.com/requests/requests-kerberos) under the hood.

- DBAPI

```python
from trino.dbapi import connect
from trino.auth import GSSAPIAuthentication

conn = connect(
user="<username>",
auth=GSSAPIAuthentication(...),
http_scheme="https",
...
)
```

- SQLAlchemy

```python
from sqlalchemy import create_engine
from trino.auth import GSSAPIAuthentication

engine = create_engine(
"trino://<username>@<host>:<port>/<catalog>",
connect_args={
"auth": GSSAPIAuthentication(...),
"http_scheme": "https",
}
)
```

## User impersonation

In the case where user who submits the query is not the same as user who authenticates to Trino server (e.g in Superset),
Expand Down
2 changes: 2 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
readme = f.read()

kerberos_require = ["requests_kerberos"]
gssapi_require = ["requests_kerberos"]
sqlalchemy_require = ["sqlalchemy >= 1.3"]
external_authentication_token_cache_require = ["keyring"]

Expand Down Expand Up @@ -86,6 +87,7 @@
extras_require={
"all": all_require,
"kerberos": kerberos_require,
"gssapi": gssapi_require,
"sqlalchemy": sqlalchemy_require,
"tests": tests_require,
"external-authentication-token-cache": external_authentication_token_cache_require,
Expand Down
22 changes: 15 additions & 7 deletions tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytest
import requests
from httpretty import httprettified
from requests_gssapi.exceptions import SPNEGOExchangeError
from requests_kerberos.exceptions import KerberosExchangeError
from tzlocal import get_localzone_name # type: ignore

Expand All @@ -39,7 +40,7 @@
_post_statement_requests,
)
from trino import __version__, constants
from trino.auth import KerberosAuthentication, _OAuth2TokenBearer
from trino.auth import GSSAPIAuthentication, KerberosAuthentication, _OAuth2TokenBearer
from trino.client import (
ClientSession,
TrinoQuery,
Expand Down Expand Up @@ -883,15 +884,22 @@ def retry_count(self):
return self._retry_count


def test_authentication_fail_retry(monkeypatch):
post_retry = RetryRecorder(error=KerberosExchangeError())
@pytest.mark.parametrize(
"auth_method, retry_exception",
[
(KerberosAuthentication, KerberosExchangeError),
(GSSAPIAuthentication, SPNEGOExchangeError),
]
)
def test_authentication_fail_retry(auth_class, retry_exception_class, monkeypatch):
post_retry = RetryRecorder(error=retry_exception_class())
monkeypatch.setattr(TrinoRequest.http.Session, "post", post_retry)

get_retry = RetryRecorder(error=KerberosExchangeError())
get_retry = RetryRecorder(error=retry_exception_class())
monkeypatch.setattr(TrinoRequest.http.Session, "get", get_retry)

attempts = 3
kerberos_auth = KerberosAuthentication()
kerberos_auth = auth_class()
req = TrinoRequest(
host="coordinator",
port=8080,
Expand All @@ -903,11 +911,11 @@ def test_authentication_fail_retry(monkeypatch):
max_attempts=attempts,
)

with pytest.raises(KerberosExchangeError):
with pytest.raises(retry_exception_class):
req.post("URL")
assert post_retry.retry_count == attempts

with pytest.raises(KerberosExchangeError):
with pytest.raises(retry_exception_class):
req.get("URL")
assert post_retry.retry_count == attempts

Expand Down
97 changes: 97 additions & 0 deletions trino/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,103 @@ def __eq__(self, other: object) -> bool:
and self._ca_bundle == other._ca_bundle)


class GSSAPIAuthentication(Authentication):
def __init__(
self,
config: Optional[str] = None,
service_name: Optional[str] = None,
mutual_authentication: bool = False,
force_preemptive: bool = False,
hostname_override: Optional[str] = None,
sanitize_mutual_error_response: bool = True,
principal: Optional[str] = None,
delegate: bool = False,
ca_bundle: Optional[str] = None,
) -> None:
self._config = config
self._service_name = service_name
self._mutual_authentication = mutual_authentication
self._force_preemptive = force_preemptive
self._hostname_override = hostname_override
self._sanitize_mutual_error_response = sanitize_mutual_error_response
self._principal = principal
self._delegate = delegate
self._ca_bundle = ca_bundle

def set_http_session(self, http_session: Session) -> Session:
try:
import requests_gssapi
except ImportError:
raise RuntimeError("unable to import requests_gssapi")

if self._config:
os.environ["KRB5_CONFIG"] = self._config
http_session.trust_env = False
http_session.auth = requests_gssapi.HTTPSPNEGOAuth(
mutual_authentication=self._mutual_authentication,
opportunistic_auth=self._force_preemptive,
target_name=self._get_target_name(self._hostname_override, self._service_name),
sanitize_mutual_error_response=self._sanitize_mutual_error_response,
creds=self._get_credentials(self._principal),
delegate=self._delegate,
)
if self._ca_bundle:
http_session.verify = self._ca_bundle
return http_session

def _get_credentials(self, principal: Optional[str] = None) -> Any:
if principal:
try:
import gssapi
except ImportError:
raise RuntimeError("unable to import gssapi")

name = gssapi.Name(principal, gssapi.NameType.user)
return gssapi.Credentials(name=name, usage="initiate")

return None

def _get_target_name(
self,
hostname_override: Optional[str] = None,
service_name: Optional[str] = None,
) -> Any:
if service_name is not None:
try:
import gssapi
except ImportError:
raise RuntimeError("unable to import gssapi")

if hostname_override is None:
raise ValueError("service name must be used together with hostname_override")

kerb_spn = "{0}@{1}".format(service_name, hostname_override)
return gssapi.Name(kerb_spn, gssapi.NameType.hostbased_service)

return hostname_override

def get_exceptions(self) -> Tuple[Any, ...]:
try:
from requests_gssapi.exceptions import SPNEGOExchangeError

return SPNEGOExchangeError,
except ImportError:
raise RuntimeError("unable to import requests_kerberos")

def __eq__(self, other: object) -> bool:
if not isinstance(other, GSSAPIAuthentication):
return False
return (self._config == other._config
and self._service_name == other._service_name
and self._mutual_authentication == other._mutual_authentication
and self._force_preemptive == other._force_preemptive
and self._hostname_override == other._hostname_override
and self._sanitize_mutual_error_response == other._sanitize_mutual_error_response
and self._principal == other._principal
and self._delegate == other._delegate
and self._ca_bundle == other._ca_bundle)


class BasicAuthentication(Authentication):
def __init__(self, username: str, password: str):
self._username = username
Expand Down

0 comments on commit 5bc49e5

Please sign in to comment.