From 32fadaa30c00be2ce3c9a5fc1c995b46715b7d41 Mon Sep 17 00:00:00 2001 From: Yeison Vargas Date: Mon, 25 Mar 2024 10:25:57 -0500 Subject: [PATCH] feat: add headless auth --- safety/auth/cli.py | 31 ++++++---- safety/auth/main.py | 5 +- safety/auth/server.py | 133 +++++++++++++++++++++++++--------------- tests/auth/test_cli.py | 2 +- tests/auth/test_main.py | 6 +- tests/test_cli.py | 48 +++++++++------ tests/test_safety.py | 2 + 7 files changed, 143 insertions(+), 84 deletions(-) diff --git a/safety/auth/cli.py b/safety/auth/cli.py index 6d07e2b..4d47222 100644 --- a/safety/auth/cli.py +++ b/safety/auth/cli.py @@ -96,7 +96,7 @@ def render_successful_login(auth: Auth, @auth_app.command(name=CMD_LOGIN_NAME, help=CLI_AUTH_LOGIN_HELP) -def login(ctx: typer.Context): +def login(ctx: typer.Context, headless: bool = False): """ Authenticate Safety CLI with your safetycli.com account using your default browser. """ @@ -105,22 +105,28 @@ def login(ctx: typer.Context): fail_if_authenticated(ctx, with_msg=MSG_FAIL_LOGIN_AUTHED) console.print() - brief_msg: str = "Redirecting your browser to log in; once authenticated, " \ - "return here to start using Safety" - uri, initial_state = get_authorization_data(client=ctx.obj.auth.client, - code_verifier=ctx.obj.auth.code_verifier, - organization=ctx.obj.auth.org) + info = None - if ctx.obj.auth.org: + brief_msg: str = "Redirecting your browser to log in; once authenticated, " \ + "return here to start using Safety" + + if ctx.obj.auth.org: console.print(f"Logging into [bold]{ctx.obj.auth.org.name}[/bold] " \ "organization.") - + + if headless: + brief_msg = "Running in headless mode. Please copy and open the following URL in a browser" + + + uri, initial_state = get_authorization_data(client=ctx.obj.auth.client, + code_verifier=ctx.obj.auth.code_verifier, + organization=ctx.obj.auth.org, headless=headless) click.secho(brief_msg) click.echo() - info = process_browser_callback(uri, - initial_state=initial_state, ctx=ctx) + info = process_browser_callback(uri, initial_state=initial_state, ctx=ctx, headless=headless) + if info: if info.get("email", None): @@ -128,6 +134,9 @@ def login(ctx: typer.Context): if ctx.obj.auth.org and ctx.obj.auth.org.name: organization = ctx.obj.auth.org.name ctx.obj.auth.refresh_from(info) + if headless: + console.print() + render_successful_login(ctx.obj.auth, organization=organization) console.print() @@ -149,7 +158,7 @@ def login(ctx: typer.Context): else: msg += "Error logging into Safety." - msg += " Please try again, or use [bold]`safety auth –help`[/bold] " \ + msg += " Please try again, or use [bold]`safety auth -–help`[/bold] " \ "for more information[/red]" console.print(msg, emoji=True) diff --git a/safety/auth/main.py b/safety/auth/main.py index eadbb64..2eeb3ef 100644 --- a/safety/auth/main.py +++ b/safety/auth/main.py @@ -2,6 +2,7 @@ import json from typing import Any, Dict, Optional, Tuple, Union +from urllib.parse import urlencode from authlib.oidc.core import CodeIDToken from authlib.jose import jwt @@ -17,9 +18,9 @@ def get_authorization_data(client, code_verifier: str, organization: Optional[Organization] = None, - sign_up: bool = False, ensure_auth: bool = False) -> Tuple[str, str]: + sign_up: bool = False, ensure_auth: bool = False, headless: bool = False) -> Tuple[str, str]: - kwargs = {'sign_up': sign_up, 'locale': 'en', 'ensure_auth': ensure_auth} + kwargs = {'sign_up': sign_up, 'locale': 'en', 'ensure_auth': ensure_auth, 'headless': headless} if organization: kwargs['organization'] = organization.id diff --git a/safety/auth/server.py b/safety/auth/server.py index 45e8224..3559c6e 100644 --- a/safety/auth/server.py +++ b/safety/auth/server.py @@ -1,4 +1,5 @@ import http.server +import json import logging import socket import sys @@ -13,6 +14,8 @@ from safety.auth.constants import AUTH_SERVER_URL, CLI_AUTH_SUCCESS, CLI_LOGOUT_SUCCESS, HOST from safety.auth.main import save_auth_config +from authlib.integrations.base_client.errors import OAuthError +from rich.prompt import Prompt LOG = logging.getLogger(__name__) @@ -33,40 +36,49 @@ def find_available_port(): return None +def auth_process(code: str, state: str, initial_state: str, code_verifier, client): + err = None + + if initial_state is None or initial_state != state: + err = "The state parameter value provided does not match the expected " \ + "value. The state parameter is used to protect against Cross-Site " \ + "Request Forgery (CSRF) attacks. For security reasons, the " \ + "authorization process cannot proceed with an invalid state " \ + "parameter value. Please try again, ensuring that the state " \ + "parameter value provided in the authorization request matches " \ + "the value returned in the callback." + + if err: + click.secho(f'Error: {err}', fg='red') + sys.exit(1) + + try: + tokens = client.fetch_token(url=f'{AUTH_SERVER_URL}/oauth/token', + code_verifier=code_verifier, + client_id=client.client_id, + grant_type='authorization_code', code=code) + + save_auth_config(access_token=tokens['access_token'], + id_token=tokens['id_token'], + refresh_token=tokens['refresh_token']) + return client.fetch_user_info() + + except Exception as e: + LOG.exception(e) + sys.exit(1) class CallbackHandler(http.server.BaseHTTPRequestHandler): def auth(self, code: str, state: str, err, error_description): initial_state = self.server.initial_state ctx = self.server.ctx - if initial_state is None or initial_state != state: - err = "The state parameter value provided does not match the expected" \ - "value. The state parameter is used to protect against Cross-Site " \ - "Request Forgery (CSRF) attacks. For security reasons, the " \ - "authorization process cannot proceed with an invalid state " \ - "parameter value. Please try again, ensuring that the state " \ - "parameter value provided in the authorization request matches " \ - "the value returned in the callback." - - if err: - click.secho(f'Error: {err}', fg='red') - sys.exit(1) + result = auth_process(code=code, + state=state, + initial_state=initial_state, + code_verifier=ctx.obj.auth.code_verifier, + client=ctx.obj.auth.client) - try: - tokens = ctx.obj.auth.client.fetch_token(url=f'{AUTH_SERVER_URL}/oauth/token', - code_verifier=ctx.obj.auth.code_verifier, - client_id=ctx.obj.auth.client.client_id, - grant_type='authorization_code', code=code) - - save_auth_config(access_token=tokens['access_token'], - id_token=tokens['id_token'], - refresh_token=tokens['refresh_token']) - self.server.callback = ctx.obj.auth.client.fetch_user_info() - - except Exception as e: - LOG.exception(e) - sys.exit(1) - + self.server.callback = result self.do_redirect(location=CLI_AUTH_SUCCESS, params={}) def logout(self): @@ -132,27 +144,52 @@ def handle_timeout(self) -> None: sys.exit(1) try: - server = ThreadedHTTPServer((HOST, PORT), CallbackHandler) - server.initial_state = kwargs.get("initial_state", None) - server.timeout = kwargs.get("timeout", 600) - # timeout = kwargs.get("timeout", None) - # timeout = float(timeout) if timeout else None - server.ctx = kwargs.get("ctx", None) - server_thread = threading.Thread(target=server.handle_request) - server_thread.start() - - target = f"{uri}&port={PORT}" - console.print(f"If the browser does not automatically open in 5 seconds, " \ - "copy and paste this url into your browser: " \ - f"[link={target}]{target}[/link]") - click.echo() - - wait_msg = "waiting for browser authentication" - - with console.status(wait_msg, spinner="bouncingBar"): - time.sleep(2) - click.launch(target) - server_thread.join() + headless = kwargs.get("headless", False) + initial_state = kwargs.get("initial_state", None) + ctx = kwargs.get("ctx", None) + + message = "Copy and paste this url into your browser:" + + + if not headless: + server = ThreadedHTTPServer((HOST, PORT), CallbackHandler) + server.initial_state = initial_state + server.timeout = kwargs.get("timeout", 600) + server.ctx = ctx + server_thread = threading.Thread(target=server.handle_request) + server_thread.start() + message = f"If the browser does not automatically open in 5 seconds, " \ + "copy and paste this url into your browser:" + + target = uri if headless else f"{uri}&port={PORT}" + console.print(f"{message} [link={target}]{target}[/link]") + console.print() + + if headless: + + exchange_data = None + while not exchange_data: + auth_code_text = Prompt.ask("Paste the response here", default=None, console=console) + try: + exchange_data = json.loads(auth_code_text) + state = exchange_data["state"] + code = exchange_data["code"] + except Exception as e: + code = state = None + + return auth_process(code=code, + state=state, + initial_state=initial_state, + code_verifier=ctx.obj.auth.code_verifier, + client=ctx.obj.auth.client) + else: + + wait_msg = "waiting for browser authentication" + + with console.status(wait_msg, spinner="bouncingBar"): + time.sleep(2) + click.launch(target) + server_thread.join() except OSError as e: if e.errno == socket.errno.EADDRINUSE: diff --git a/tests/auth/test_cli.py b/tests/auth/test_cli.py index b968cfe..9b035b3 100644 --- a/tests/auth/test_cli.py +++ b/tests/auth/test_cli.py @@ -28,7 +28,7 @@ def test_auth_calls_login(self, process_browser_callback, get_authorization_data.assert_called_once() process_browser_callback.assert_called_once_with(auth_data[0], initial_state=auth_data[1], - ctx=ANY) + ctx=ANY, headless=False) expected = [ "", diff --git a/tests/auth/test_main.py b/tests/auth/test_main.py index f634abc..e425ba7 100644 --- a/tests/auth/test_main.py +++ b/tests/auth/test_main.py @@ -30,7 +30,8 @@ def test_get_authorization_data(self): "sign_up": False, "locale": "en", "ensure_auth": False, - "organization": org_id + "organization": org_id, + "headless": False } client.create_authorization_url.assert_called_once_with( @@ -42,7 +43,8 @@ def test_get_authorization_data(self): kwargs = { "sign_up": False, "locale": "en", - "ensure_auth":False + "ensure_auth":False, + "headless": False } client.create_authorization_url.assert_called_once_with( diff --git a/tests/test_cli.py b/tests/test_cli.py index 54b02ef..8260118 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -204,8 +204,7 @@ def test_validate_with_basic_policy_file(self): result = self.runner.invoke(cli.cli, ['validate', 'policy_file', '3.0', '--path', path]) cleaned_stdout = click.unstyle(result.stdout) msg = 'The Safety policy (3.0) file (Used for scan and system-scan commands) was successfully parsed with the following values:\n' - parsed = json.dumps( - { + parsed = { "version": "3.0", "scan": { "max_depth": 6, @@ -230,19 +229,19 @@ def test_validate_with_basic_policy_file(self): }, "fail_scan": { "dependency_vulnerabilities": { - "enabled": True, - "fail_on_any_of": { - "cvss_severity": [ - "critical", - "high", - "medium" - ], - "exploitability": [ - "critical", - "high", - "medium" - ] - } + "enabled": True, + "fail_on_any_of": { + "cvss_severity": [ + "critical", + "high", + "medium", + ], + "exploitability": [ + "critical", + "high", + "medium", + ] + } } }, "security_updates": { @@ -252,12 +251,21 @@ def test_validate_with_basic_policy_file(self): ] } } - }, - indent=2 - ) + '\n' + } - self.assertEqual(msg + parsed, cleaned_stdout) - self.assertEqual(result.exit_code, 0) + msg_stdout, parsed_policy = cleaned_stdout.split('\n', 1) + msg_stdout += '\n' + parsed_policy = json.loads(parsed_policy.replace('\n', '')) + + fail_scan = parsed_policy.get("fail_scan", None) + self.assertIsNotNone(fail_scan) + fail_of_any = fail_scan["dependency_vulnerabilities"]["fail_on_any_of"] + fail_of_any["cvss_severity"] = sorted(fail_of_any["cvss_severity"]) + fail_of_any["exploitability"] = sorted(fail_of_any["exploitability"]) + + self.assertEqual(msg, msg_stdout) + self.assertEqual(parsed, parsed_policy) + self.assertEqual(result.exit_code, 0) def test_validate_with_policy_file_using_invalid_keyword(self): diff --git a/tests/test_safety.py b/tests/test_safety.py index 3fda4cf..ba0debc 100644 --- a/tests/test_safety.py +++ b/tests/test_safety.py @@ -494,6 +494,8 @@ def test_get_announcements_http_ok(self, get_used_options): @patch("safety.util.get_used_options") @patch.object(click, 'get_current_context', Mock(command=Mock(name=Mock(return_value='check')))) def test_get_announcements_wrong_json_response_handling(self, get_used_options): + get_used_options.return_value = {} + # wrong JSON structure announcements = { "type": "notice",