From 4b028d97105f0a40c0ad18da8162712d59a14381 Mon Sep 17 00:00:00 2001 From: Yeison Vargas Date: Tue, 30 Apr 2024 17:34:17 -0500 Subject: [PATCH] feat: add SAFETY_DB_DIR env var to the scan command --- safety/auth/cli_utils.py | 2 +- safety/cli.py | 11 ++++++++++- safety/safety.py | 18 ++++++++++-------- safety/scan/finder/handlers.py | 10 ++++++++-- 4 files changed, 29 insertions(+), 12 deletions(-) diff --git a/safety/auth/cli_utils.py b/safety/auth/cli_utils.py index 4cffe1b..527f369 100644 --- a/safety/auth/cli_utils.py +++ b/safety/auth/cli_utils.py @@ -54,7 +54,7 @@ def update_token(tokens, **kwargs): try: openid_config = client_session.get(url=OPENID_CONFIG_URL, timeout=REQUEST_TIMEOUT).json() except Exception as e: - LOG.exception('Unable to load the openID config: %s', e) + LOG.debug('Unable to load the openID config: %s', e) openid_config = {} client_session.metadata["token_endpoint"] = openid_config.get("token_endpoint", diff --git a/safety/cli.py b/safety/cli.py index 3e8bcad..6edef4d 100644 --- a/safety/cli.py +++ b/safety/cli.py @@ -49,11 +49,20 @@ LOG = logging.getLogger(__name__) + +def configure_logger(ctx, param, debug): + level = logging.CRITICAL + + if debug: + level = logging.DEBUG + + logging.basicConfig(format='%(asctime)s %(name)s => %(message)s', level=level) + @click.group(cls=SafetyCLILegacyGroup, help=CLI_MAIN_INTRODUCTION, epilog=DEFAULT_EPILOG) @auth_options() @proxy_options @click.option('--disable-optional-telemetry', default=False, is_flag=True, show_default=True, help=CLI_DISABLE_OPTIONAL_TELEMETRY_DATA_HELP) -@click.option('--debug', default=False, help=CLI_DEBUG_HELP) +@click.option('--debug', default=False, help=CLI_DEBUG_HELP, callback=configure_logger) @click.version_option(version=get_safety_version()) @click.pass_context @inject_session diff --git a/safety/safety.py b/safety/safety.py index 86aa172..741588e 100644 --- a/safety/safety.py +++ b/safety/safety.py @@ -199,10 +199,13 @@ def post_results(session, safety_json, policy_file): return {} -def fetch_database_file(path, db_name, ecosystem: Ecosystem = Ecosystem.PYTHON): - full_path = os.path.join(path, db_name) - if not os.path.exists(full_path): +def fetch_database_file(path: str, db_name: str, + ecosystem: Optional[Ecosystem] = None): + full_path = (Path(path) / (ecosystem.value if ecosystem else '') / db_name).expanduser().resolve() + + if not full_path.exists(): raise DatabaseFileNotFoundError(db=path) + with open(full_path) as f: return json.loads(f.read()) @@ -220,9 +223,7 @@ def is_valid_database(db) -> bool: def fetch_database(session, full=False, db=False, cached=0, telemetry=True, ecosystem: Ecosystem = Ecosystem.PYTHON, from_cache=True): if session.is_using_auth_credentials(): - mirrors = API_MIRRORS - elif db: - mirrors = [db] + mirrors = [db] if db else API_MIRRORS else: mirrors = OPEN_MIRRORS @@ -233,7 +234,8 @@ def fetch_database(session, full=False, db=False, cached=0, telemetry=True, data = fetch_database_url(session, mirror, db_name=db_name, cached=cached, telemetry=telemetry, ecosystem=ecosystem, from_cache=from_cache) else: - data = fetch_database_file(mirror, db_name=db_name, ecosystem=ecosystem) + data = fetch_database_file(mirror, db_name=db_name, + ecosystem=ecosystem) if data: if is_valid_database(data): return data @@ -1000,7 +1002,7 @@ def get_licenses(*, session=None, db_mirror=False, cached=0, telemetry=True): licenses = fetch_database_url(session, mirror, db_name=db_name, cached=cached, telemetry=telemetry) else: - licenses = fetch_database_file(mirror, db_name=db_name) + licenses = fetch_database_file(mirror, db_name=db_name, ecosystem=None) if licenses: return licenses raise DatabaseFetchError() diff --git a/safety/scan/finder/handlers.py b/safety/scan/finder/handlers.py index 395d9e0..4e2f696 100644 --- a/safety/scan/finder/handlers.py +++ b/safety/scan/finder/handlers.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +import os from pathlib import Path from types import MappingProxyType from typing import Dict, List, Optional, Tuple @@ -49,12 +50,17 @@ def __init__(self) -> None: def download_required_assets(self, session): from safety.safety import fetch_database + + SAFETY_DB_DIR = os.getenv("SAFETY_DB_DIR") + + db = False if SAFETY_DB_DIR is None else SAFETY_DB_DIR + - fetch_database(session=session, full=False, db=False, cached=True, + fetch_database(session=session, full=False, db=db, cached=True, telemetry=True, ecosystem=Ecosystem.PYTHON, from_cache=False) - fetch_database(session=session, full=True, db=False, cached=True, + fetch_database(session=session, full=True, db=db, cached=True, telemetry=True, ecosystem=Ecosystem.PYTHON, from_cache=False)