diff --git a/src/codegate/ca/codegate_ca.py b/src/codegate/ca/codegate_ca.py index 1ac6d128..f534e82e 100644 --- a/src/codegate/ca/codegate_ca.py +++ b/src/codegate/ca/codegate_ca.py @@ -368,35 +368,6 @@ def generate_certificates(self) -> Tuple[str, str]: ) # Print instructions for trusting the certificates - logger.info( - """ -Certificates generated successfully in the 'certs' directory -To trust these certificates: - -On macOS: -`sudo security add-trusted-cert -d -r trustRoot -k /Library/Keychains/System.keychain certs/ca.crt` - -On Windows (PowerShell as Admin): -`Import-Certificate -FilePath "certs\\ca.crt" -CertStoreLocation Cert:\\LocalMachine\\Root` - -On Linux: -`sudo cp certs/ca.crt /usr/local/share/ca-certificates/codegate.crt` -`sudo update-ca-certificates` - -For VSCode, add to settings.json: -{ - "http.proxy": "https://localhost:8990", - "http.proxyStrictSSL": true, - "http.proxySupport": "on", - "github.copilot.advanced": { - "debug.useNodeFetcher": true, - "debug.useElectronFetcher": true, - "debug.testOverrideProxyUrl": "https://localhost:8990", - "debug.overrideProxyUrl": "https://localhost:8990" - }, -} -""" - ) logger.debug("Certificates generated successfully") return server_cert, server_key @@ -422,23 +393,21 @@ def create_ssl_context(self) -> ssl.SSLContext: logger.debug("SSL context created successfully") return ssl_context - def ensure_certificates_exist(self) -> None: + def check_certificates_exist(self) -> bool: + """Check if SSL certificates exist""" + logger.debug("Checking if certificates exist fn: check_certificates_exist") + return os.path.exists( + os.path.join(Config.get_config().certs_dir, Config.get_config().server_cert) + ) and os.path.exists( + os.path.join(Config.get_config().certs_dir, Config.get_config().server_key) + ) + + def ensure_certificates_exist(self) -> bool: """Ensure SSL certificates exist, generate if they don't""" logger.debug("Ensuring certificates exist. fn ensure_certificates_exist") - if not ( - os.path.exists( - os.path.join(Config.get_config().certs_dir, Config.get_config().server_cert) - ) - and os.path.exists( - os.path.join(Config.get_config().certs_dir, Config.get_config().server_key) - ) - ): - logger.debug("Certificates not found, generating new certificates") + if not self.check_certificates_exist(): + logger.info("Certificates not found. Generating new certificates.") self.generate_certificates() - else: - server_cert = Config.get_config().server_cert - server_key = Config.get_config().server_key - logger.debug(f"Certificates found at: {server_cert} and {server_key}.") def get_ssl_context(self) -> ssl.SSLContext: """Get SSL context with certificates""" diff --git a/src/codegate/cli.py b/src/codegate/cli.py index 4c4ce870..b230eba9 100644 --- a/src/codegate/cli.py +++ b/src/codegate/cli.py @@ -448,6 +448,12 @@ def restore_backup(backup_path: Path, backup_name: str) -> None: default=None, help="Name that will be given to the created server-key.", ) +@click.option( + "--force-certs", + is_flag=True, + default=False, + help="Force the generation of certificates even if they already exist.", +) @click.option( "--log-level", type=click.Choice([level.value for level in LogLevel]), @@ -466,6 +472,7 @@ def generate_certs( ca_key_name: Optional[str], server_cert_name: Optional[str], server_key_name: Optional[str], + force_certs: bool, log_level: Optional[str], log_format: Optional[str], ) -> None: @@ -476,12 +483,22 @@ def generate_certs( ca_key=ca_key_name, server_cert=server_cert_name, server_key=server_key_name, + force_certs=force_certs, cli_log_level=log_level, cli_log_format=log_format, ) setup_logging(cfg.log_level, cfg.log_format) + ca = CertificateAuthority.get_instance() - ca.generate_certificates() + should_generate = force_certs or not ca.check_certificates_exist() + + if should_generate: + ca.generate_certificates() + click.echo("Certificates generated successfully.") + click.echo(f"Certificates saved to: {cfg.certs_dir}") + click.echo("Make sure to add the new CA certificate to the operating system trust store.") + else: + click.echo("Certificates already exist. Skipping generation...") def main() -> None: diff --git a/src/codegate/config.py b/src/codegate/config.py index 45c634b7..be7b0143 100644 --- a/src/codegate/config.py +++ b/src/codegate/config.py @@ -50,6 +50,7 @@ class Config: ca_key: str = "ca.key" server_cert: str = "server.crt" server_key: str = "server.key" + force_certs: bool = False # Provider URLs with defaults provider_urls: Dict[str, str] = field(default_factory=lambda: DEFAULT_PROVIDER_URLS.copy()) @@ -142,6 +143,7 @@ def from_file(cls, config_path: Union[str, Path]) -> "Config": ca_key=config_data.get("ca_key", cls.ca_key), server_cert=config_data.get("server_cert", cls.server_cert), server_key=config_data.get("server_key", cls.server_key), + force_certs=config_data.get("force_certs", cls.force_certs), prompts=prompts_config, provider_urls=provider_urls, ) @@ -187,6 +189,8 @@ def from_env(cls) -> "Config": config.server_cert = os.environ["CODEGATE_SERVER_CERT"] if "CODEGATE_SERVER_KEY" in os.environ: config.server_key = os.environ["CODEGATE_SERVER_KEY"] + if "CODEGATE_FORCE_CERTS" in os.environ: + config.force_certs = os.environ["CODEGATE_FORCE_CERTS"] # Load provider URLs from environment variables for provider in DEFAULT_PROVIDER_URLS.keys(): @@ -216,6 +220,7 @@ def load( ca_key: Optional[str] = None, server_cert: Optional[str] = None, server_key: Optional[str] = None, + force_certs: Optional[bool] = None, db_path: Optional[str] = None, ) -> "Config": """Load configuration with priority resolution. @@ -242,6 +247,7 @@ def load( ca_key: Optional path to CA key server_cert: Optional path to server certificate server_key: Optional path to server key + force_certs: Optional flag to force certificate generation db_path: Optional path to the SQLite database file Returns: @@ -289,6 +295,8 @@ def load( config.server_cert = env_config.server_cert if "CODEGATE_SERVER_KEY" in os.environ: config.server_key = env_config.server_key + if "CODEGATE_FORCE_CERTS" in os.environ: + config.force_certs = env_config.force_certs # Override provider URLs from environment for provider, url in env_config.provider_urls.items(): @@ -325,16 +333,8 @@ def load( config.server_key = server_key if db_path is not None: config.db_path = db_path - if certs_dir is not None: - config.certs_dir = certs_dir - if ca_cert is not None: - config.ca_cert = ca_cert - if ca_key is not None: - config.ca_key = ca_key - if server_cert is not None: - config.server_cert = server_cert - if server_key is not None: - config.server_key = server_key + if force_certs is not None: + config.force_certs = force_certs # Set the __config class attribute Config.__config = config diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index b96ed07b..720fc2af 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -187,7 +187,7 @@ def get_last_user_message_idx(request: ChatCompletionRequest) -> int: if request.get("messages") is None: return -1 - for idx, message in reversed(list(enumerate(request['messages']))): + for idx, message in reversed(list(enumerate(request["messages"]))): if message.get("role", "") == "user": return idx diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py index fb333e42..533d70cf 100644 --- a/src/codegate/pipeline/codegate_context_retriever/codegate.py +++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py @@ -110,9 +110,7 @@ async def process( return PipelineResult(request=request) # Look for matches in vector DB using list of packages as filter - searched_objects = await self.get_objects_from_search( - user_messages, ecosystem, packages - ) + searched_objects = await self.get_objects_from_search(user_messages, ecosystem, packages) logger.info( f"Found {len(searched_objects)} matches in the database", @@ -149,4 +147,3 @@ async def process( message["content"] = context_msg return PipelineResult(request=new_request, context=context) - diff --git a/tests/test_cli.py b/tests/test_cli.py index bbe699f0..8557a63c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,24 +1,23 @@ """Tests for the server module.""" import os -from unittest.mock import MagicMock, patch, AsyncMock +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock, patch import pytest +from click.testing import CliRunner from fastapi.middleware.cors import CORSMiddleware from fastapi.testclient import TestClient from httpx import AsyncClient +from uvicorn.config import Config as UvicornConfig from codegate import __version__ from codegate.pipeline.factory import PipelineFactory from codegate.pipeline.secrets.manager import SecretsManager from codegate.providers.registry import ProviderRegistry from codegate.server import init_app -from src.codegate.cli import UvicornServer -from src.codegate.cli import cli -from src.codegate.codegate_logging import LogLevel, LogFormat -from uvicorn.config import Config as UvicornConfig -from click.testing import CliRunner -from pathlib import Path +from src.codegate.cli import UvicornServer, cli +from src.codegate.codegate_logging import LogFormat, LogLevel @pytest.fixture @@ -176,12 +175,12 @@ def mock_app(): @pytest.fixture def uvicorn_config(mock_app): # Assuming mock_app is defined to simulate ASGI application - return UvicornConfig(app=mock_app, host='localhost', port=8000, log_level='info') + return UvicornConfig(app=mock_app, host="localhost", port=8000, log_level="info") @pytest.fixture def server_instance(uvicorn_config): - with patch('src.codegate.cli.Server', autospec=True) as mock_server_class: + with patch("src.codegate.cli.Server", autospec=True) as mock_server_class: mock_server_instance = mock_server_class.return_value mock_server_instance.serve = AsyncMock() yield UvicornServer(uvicorn_config, mock_server_instance) @@ -200,20 +199,22 @@ def cli_runner(): @pytest.fixture def mock_logging(mocker): - return mocker.patch('your_cli_module.structlog.get_logger') + return mocker.patch("your_cli_module.structlog.get_logger") @pytest.fixture def mock_setup_logging(mocker): - return mocker.patch('your_cli_module.setup_logging') + return mocker.patch("your_cli_module.setup_logging") def test_serve_default_options(cli_runner): """Test serve command with default options.""" # Use patches for run_servers and logging setup - with patch("src.codegate.cli.run_servers") as mock_run, \ - patch("src.codegate.cli.structlog.get_logger") as mock_logging, \ - patch("src.codegate.cli.setup_logging") as mock_setup_logging: + with ( + patch("src.codegate.cli.run_servers") as mock_run, + patch("src.codegate.cli.structlog.get_logger") as mock_logging, + patch("src.codegate.cli.setup_logging") as mock_setup_logging, + ): logger_instance = MagicMock() mock_logging.return_value = logger_instance @@ -236,9 +237,11 @@ def test_serve_default_options(cli_runner): def test_serve_custom_options(cli_runner): """Test serve command with custom options.""" - with patch("src.codegate.cli.run_servers") as mock_run, \ - patch("src.codegate.cli.structlog.get_logger") as mock_logging, \ - patch("src.codegate.cli.setup_logging") as mock_setup_logging: + with ( + patch("src.codegate.cli.run_servers") as mock_run, + patch("src.codegate.cli.structlog.get_logger") as mock_logging, + patch("src.codegate.cli.setup_logging") as mock_setup_logging, + ): logger_instance = MagicMock() mock_logging.return_value = logger_instance @@ -248,15 +251,24 @@ def test_serve_custom_options(cli_runner): cli, [ "serve", - "--port", "8989", - "--host", "localhost", - "--log-level", "DEBUG", - "--log-format", "TEXT", - "--certs-dir", "./custom-certs", - "--ca-cert", "custom-ca.crt", - "--ca-key", "custom-ca.key", - "--server-cert", "custom-server.crt", - "--server-key", "custom-server.key", + "--port", + "8989", + "--host", + "localhost", + "--log-level", + "DEBUG", + "--log-format", + "TEXT", + "--certs-dir", + "./custom-certs", + "--ca-cert", + "custom-ca.crt", + "--ca-key", + "custom-ca.key", + "--server-cert", + "custom-server.crt", + "--server-key", + "custom-server.key", ], ) @@ -289,8 +301,9 @@ def test_serve_custom_options(cli_runner): # Check if Config object attributes match the expected values for key, expected_value in expected_values.items(): - assert getattr(config_arg, key) == expected_value, \ - f"{key} does not match expected value" + assert ( + getattr(config_arg, key) == expected_value + ), f"{key} does not match expected value" def test_serve_invalid_port(cli_runner): @@ -310,21 +323,25 @@ def test_serve_invalid_log_level(cli_runner): @pytest.fixture def temp_config_file(tmp_path): config_path = tmp_path / "config.yaml" - config_path.write_text(""" + config_path.write_text( + """ log_level: DEBUG log_format: JSON port: 8989 host: localhost certs_dir: ./test-certs - """) + """ + ) return config_path def test_serve_with_config_file(cli_runner, temp_config_file): """Test serve command with config file.""" - with patch("src.codegate.cli.run_servers") as mock_run, \ - patch("src.codegate.cli.structlog.get_logger") as mock_logging, \ - patch("src.codegate.cli.setup_logging") as mock_setup_logging: + with ( + patch("src.codegate.cli.run_servers") as mock_run, + patch("src.codegate.cli.structlog.get_logger") as mock_logging, + patch("src.codegate.cli.setup_logging") as mock_setup_logging, + ): logger_instance = MagicMock() mock_logging.return_value = logger_instance @@ -352,8 +369,9 @@ def test_serve_with_config_file(cli_runner, temp_config_file): # Check if passed arguments match the expected values for key, expected_value in expected_values.items(): - assert getattr(config_arg, key) == expected_value, \ - f"{key} does not match expected value" + assert ( + getattr(config_arg, key) == expected_value + ), f"{key} does not match expected value" def test_serve_with_nonexistent_config_file(cli_runner: CliRunner) -> None: @@ -366,10 +384,12 @@ def test_serve_with_nonexistent_config_file(cli_runner: CliRunner) -> None: def test_serve_priority_resolution(cli_runner: CliRunner, temp_config_file: Path) -> None: """Test serve command respects configuration priority.""" # Set up environment variables and ensure they get cleaned up after the test - with patch.dict(os.environ, {'LOG_LEVEL': 'INFO', 'PORT': '9999'}, clear=True), \ - patch('src.codegate.cli.run_servers') as mock_run, \ - patch('src.codegate.cli.structlog.get_logger') as mock_logging, \ - patch('src.codegate.cli.setup_logging') as mock_setup_logging: + with ( + patch.dict(os.environ, {"LOG_LEVEL": "INFO", "PORT": "9999"}, clear=True), + patch("src.codegate.cli.run_servers") as mock_run, + patch("src.codegate.cli.structlog.get_logger") as mock_logging, + patch("src.codegate.cli.setup_logging") as mock_setup_logging, + ): # Set up mock logger logger_instance = MagicMock() mock_logging.return_value = logger_instance @@ -406,7 +426,7 @@ def test_serve_priority_resolution(cli_runner: CliRunner, temp_config_file: Path assert result.exit_code == 0 # Ensure logging setup was called with the highest priority settings (CLI arguments) - mock_setup_logging.assert_called_once_with('ERROR', 'TEXT') + mock_setup_logging.assert_called_once_with("ERROR", "TEXT") mock_logging.assert_called_with("codegate") # Verify that the run_servers was called with the overridden settings @@ -415,8 +435,8 @@ def test_serve_priority_resolution(cli_runner: CliRunner, temp_config_file: Path expected_values = { "port": 8080, "host": "example.com", - "log_level": 'ERROR', - "log_format": 'TEXT', + "log_level": "ERROR", + "log_format": "TEXT", "certs_dir": "./cli-certs", "ca_cert": "cli-ca.crt", "ca_key": "cli-ca.key", @@ -426,15 +446,18 @@ def test_serve_priority_resolution(cli_runner: CliRunner, temp_config_file: Path # Verify if Config object attributes match the expected values from CLI arguments for key, expected_value in expected_values.items(): - assert getattr(config_arg, key) == expected_value, \ - f"{key} does not match expected value" + assert ( + getattr(config_arg, key) == expected_value + ), f"{key} does not match expected value" def test_serve_certificate_options(cli_runner: CliRunner) -> None: """Test serve command with certificate options.""" - with patch('src.codegate.cli.run_servers') as mock_run, \ - patch('src.codegate.cli.structlog.get_logger') as mock_logging, \ - patch('src.codegate.cli.setup_logging') as mock_setup_logging: + with ( + patch("src.codegate.cli.run_servers") as mock_run, + patch("src.codegate.cli.structlog.get_logger") as mock_logging, + patch("src.codegate.cli.setup_logging") as mock_setup_logging, + ): # Set up mock logger logger_instance = MagicMock() mock_logging.return_value = logger_instance @@ -461,7 +484,7 @@ def test_serve_certificate_options(cli_runner: CliRunner) -> None: assert result.exit_code == 0 # Ensure logging setup was called with expected arguments - mock_setup_logging.assert_called_once_with('INFO', 'JSON') + mock_setup_logging.assert_called_once_with("INFO", "JSON") mock_logging.assert_called_with("codegate") # Verify that run_servers was called with the provided certificate options @@ -477,14 +500,16 @@ def test_serve_certificate_options(cli_runner: CliRunner) -> None: # Check if Config object attributes match the expected values for key, expected_value in expected_values.items(): - assert getattr(config_arg, key) == expected_value, \ - f"{key} does not match expected value" + assert ( + getattr(config_arg, key) == expected_value + ), f"{key} does not match expected value" def test_main_function() -> None: """Test main function.""" with patch("sys.argv", ["cli"]), patch("codegate.cli.cli") as mock_cli: from codegate.cli import main + main() mock_cli.assert_called_once() @@ -501,8 +526,10 @@ def mock_uvicorn_server(): @pytest.mark.asyncio async def test_uvicorn_server_cleanup(mock_uvicorn_server): - with patch("asyncio.get_running_loop"), \ - patch.object(mock_uvicorn_server.server, 'shutdown', AsyncMock()): + with ( + patch("asyncio.get_running_loop"), + patch.object(mock_uvicorn_server.server, "shutdown", AsyncMock()), + ): # Mock the loop or other components as needed # Start the server or trigger the condition you want to test