Skip to content

Commit

Permalink
feat: Add timeout for OAuthCallbackHttpServer#wait_for_code to preven…
Browse files Browse the repository at this point in the history
…t CLI hangs
  • Loading branch information
timo-reymann committed Dec 10, 2023
1 parent 20d6000 commit 603ce83
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 3 deletions.
19 changes: 16 additions & 3 deletions oauth2_cli_auth/http_server.py
Expand Up @@ -2,6 +2,7 @@
from string import Template
from typing import Optional
from urllib.parse import parse_qs, urlparse
from oauth2_cli_auth._timeout import _method_with_timeout, TimeoutException


class CallbackPageTemplate:
Expand Down Expand Up @@ -153,7 +154,6 @@ class OAuthRedirectHandler(BaseHTTPRequestHandler):
def do_GET(self):
params = parse_qs(urlparse(self.path).query)


has_error = "code" not in params or len(params['code']) != 1 or params['code'][0].strip() == ""

if has_error:
Expand Down Expand Up @@ -197,9 +197,22 @@ def get_code(self):
def callback_url(self):
return f"http://localhost:{self.server_port}"

def wait_for_code(self, attempts: int = 3) -> Optional[int]:
def wait_for_code(self, attempts: int = 3, timeout_per_attempt=10) -> Optional[int]:
"""
Wait for the server to open the callback page containing the code query parameter.
It tries for #attempts with a timeout of #timeout_per_attempts for each attempt.
This prevents the CLI from getting stuck by unsolved callback URls
:param attempts: Amount of attempts
:param timeout_per_attempt: Timeout for each attempt to be successful
:return: Code from callback page or None if the callback page is not called successfully
"""
for i in range(0, attempts):
self.handle_request()
try:
_method_with_timeout(self.handle_request, timeout_seconds=timeout_per_attempt)
except TimeoutException:
continue
if self.get_code() is not None:
return self.get_code()

Expand Down
1 change: 1 addition & 0 deletions oauth2_cli_auth/http_server_test.py
Expand Up @@ -13,6 +13,7 @@ def test_http_server_ok():
threading.Thread(target=server.handle_request).start()
with urllib.request.urlopen("http://localhost:5000?code=foo") as response:
content = response.read().decode("utf-8")
assert content is not None


def test_http_server_bad_request():
Expand Down

0 comments on commit 603ce83

Please sign in to comment.