diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 1f84d13e..683bc29d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -14,8 +14,10 @@ import pytest +from tests.unit.oauth_test_utils import SERVER_ADDRESS -@pytest.fixture(scope="session") + +@pytest.fixture def sample_post_response_data(): """ This is the response to the first HTTP request (a POST) from an actual @@ -38,10 +40,10 @@ def sample_post_response_data(): """ yield { - "nextUri": "https://coordinator:8080/v1/statement/20210817_140827_00000_arvdv/1", + "nextUri": f"{SERVER_ADDRESS}:8080/v1/statement/20210817_140827_00000_arvdv/1", "id": "20210817_140827_00000_arvdv", "taskDownloadUris": [], - "infoUri": "https://coordinator:8080/query.html?20210817_140827_00000_arvdv", + "infoUri": f"{SERVER_ADDRESS}:8080/query.html?20210817_140827_00000_arvdv", "stats": { "scheduled": False, "runningSplits": 0, @@ -60,7 +62,7 @@ def sample_post_response_data(): } -@pytest.fixture(scope="session") +@pytest.fixture def sample_get_response_data(): """ This is the response to the second HTTP request (a GET) from an actual @@ -73,7 +75,7 @@ def sample_get_response_data(): """ yield { "id": "20210817_140827_00000_arvdv", - "nextUri": "coordinator:8080/v1/statement/20210817_140827_00000_arvdv/2", + "nextUri": f"{SERVER_ADDRESS}:8080/v1/statement/20210817_140827_00000_arvdv/2", "data": [ ["UUID-0", "http://worker0:8080", "0.157", False, "active"], ["UUID-1", "http://worker1:8080", "0.157", False, "active"], @@ -132,7 +134,7 @@ def sample_get_response_data(): }, ], "taskDownloadUris": [], - "partialCancelUri": "http://localhost:8080/v1/stage/20210817_140827_00000_arvdv.0", # NOQA: E501 + "partialCancelUri": f"{SERVER_ADDRESS}:8080/v1/stage/20210817_140827_00000_arvdv.0", # NOQA: E501 "stats": { "nodes": 2, "processedBytes": 880, @@ -181,11 +183,11 @@ def sample_get_response_data(): "queuedSplits": 0, "wallTimeMillis": 36, }, - "infoUri": "http://coordinator:8080/query.html?20210817_140827_00000_arvdv", # NOQA: E501 + "infoUri": f"{SERVER_ADDRESS}:8080/query.html?20210817_140827_00000_arvdv", # NOQA: E501 } -@pytest.fixture(scope="session") +@pytest.fixture def sample_get_error_response_data(): yield { "error": { @@ -195,8 +197,7 @@ def sample_get_error_response_data(): "errorType": "USER_ERROR", "failureInfo": { "errorLocation": {"columnNumber": 15, "lineNumber": 1}, - "message": "line 1:15: Schema must be specified " - "when session schema is not set", + "message": "line 1:15: Schema must be specified when session schema is not set", "stack": [ "io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:48)", "io.trino.sql.analyzer.SemanticExceptions.semanticException(SemanticExceptions.java:43)", @@ -241,7 +242,7 @@ def sample_get_error_response_data(): "message": "line 1:15: Schema must be specified when session schema is not set", }, "id": "20210817_140827_00000_arvdv", - "infoUri": "http://localhost:8080/query.html?20210817_140827_00000_arvdv", + "infoUri": f"{SERVER_ADDRESS}:8080/query.html?20210817_140827_00000_arvdv", "stats": { "completedSplits": 0, "cpuTimeMillis": 0, diff --git a/tests/unit/test_dbapi.py b/tests/unit/test_dbapi.py index b56466a2..06f794c4 100644 --- a/tests/unit/test_dbapi.py +++ b/tests/unit/test_dbapi.py @@ -14,7 +14,6 @@ from unittest.mock import patch import httpretty -from httpretty import httprettified from requests import Session from tests.unit.oauth_test_utils import ( @@ -58,7 +57,7 @@ def test_http_session_is_defaulted_when_not_specified(mock_client): assert mock_client.TrinoRequest.http.Session.return_value in request_args -@httprettified +@httpretty.activate def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sample_get_response_data): token = str(uuid.uuid4()) challenge_id = str(uuid.uuid4()) @@ -73,13 +72,15 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl httpretty.register_uri( method=httpretty.POST, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}", - body=post_statement_callback) + body=post_statement_callback + ) # bind get statement for result retrieval httpretty.register_uri( method=httpretty.GET, uri=f"{SERVER_ADDRESS}:8080{constants.URL_STATEMENT_PATH}/20210817_140827_00000_arvdv/1", - body=get_statement_callback) + body=get_statement_callback + ) # bind get token get_token_callback = GetTokenCallback(token_server, token) @@ -122,7 +123,7 @@ def test_token_retrieved_once_per_auth_instance(sample_post_response_data, sampl assert len(_get_token_requests(challenge_id)) == 2 -@httprettified +@httpretty.activate def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post_response_data, sample_get_response_data): token = str(uuid.uuid4()) @@ -188,7 +189,7 @@ def test_token_retrieved_once_when_authentication_instance_is_shared(sample_post assert len(_get_token_requests(challenge_id)) == 1 -@httprettified +@httpretty.activate def test_token_retrieved_once_when_multithreaded(sample_post_response_data, sample_get_response_data): token = str(uuid.uuid4()) challenge_id = str(uuid.uuid4()) diff --git a/trino/auth.py b/trino/auth.py index dc7b577a..d36705bd 100644 --- a/trino/auth.py +++ b/trino/auth.py @@ -395,7 +395,7 @@ def _determine_host(url: Optional[str]) -> Any: class OAuth2Authentication(Authentication): - def __init__(self, redirect_auth_url_handler: CompositeRedirectHandler = CompositeRedirectHandler([ + def __init__(self, redirect_auth_url_handler: RedirectHandler = CompositeRedirectHandler([ WebBrowserRedirectHandler(), ConsoleRedirectHandler() ])):