From d94bfad9754e74c42cb5d258095ca28d5d23d0c8 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 2 Dec 2024 17:30:10 +0100 Subject: [PATCH 1/6] Add an output pipeline Adds an output pipeline with its own context and a possibility to buffer in case a step tells the pipeline to not forward the chunk. The buffer is stored into in case a chunk is not evaluated which means that subsequent steps can keep on buffering potentially redacted text for un-redaction later. The remainder of the buffer is flushed upon consuming the stream. --- src/codegate/pipeline/output.py | 169 +++++++++++++++++++ tests/pipeline/test_output.py | 289 ++++++++++++++++++++++++++++++++ 2 files changed, 458 insertions(+) create mode 100644 src/codegate/pipeline/output.py create mode 100644 tests/pipeline/test_output.py diff --git a/src/codegate/pipeline/output.py b/src/codegate/pipeline/output.py new file mode 100644 index 00000000..09239644 --- /dev/null +++ b/src/codegate/pipeline/output.py @@ -0,0 +1,169 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import AsyncIterator, Optional + +from litellm import ModelResponse +from litellm.types.utils import Delta, StreamingChoices + +from codegate.pipeline.base import PipelineContext + + +@dataclass +class OutputPipelineContext: + """ + Context passed between output pipeline steps. + + Does not include the input context, that one is separate. + """ + + # We store the messages that are not yet sent to the client in the buffer. + # One reason for this might be that the buffer contains a secret that we want to de-obfuscate + buffer: list[str] = field(default_factory=list) + + +class OutputPipelineStep(ABC): + """ + Base class for output pipeline steps + The process method should be implemented by subclasses and handles + processing of a single chunk of the stream. + """ + + @property + @abstractmethod + def name(self) -> str: + """Returns the name of this pipeline step""" + pass + + @abstractmethod + async def process_chunk( + self, + chunk: ModelResponse, + context: OutputPipelineContext, + input_context: Optional[PipelineContext] = None, + ) -> Optional[ModelResponse]: + """ + Process a single chunk of the stream. + + Args: + - chunk: The input chunk to process, normalized to ModelResponse + - context: The output pipeline context. Can be used to store state between steps, mainly + the buffer. + - input_context: The input context from processing the user's input. Can include the secrets + obfuscated in the user message or code snippets in the user message. + + Return: + - None to pause the stream + - Modified or unmodified input chunk to pass through + """ + pass + + +class OutputPipelineInstance: + """ + Handles processing of a single stream + Think of this class as steps + buffer + """ + + def __init__( + self, + pipeline_steps: list[OutputPipelineStep], + input_context: Optional[PipelineContext] = None, + ): + self._input_context = input_context + self._pipeline_steps = pipeline_steps + self._context = OutputPipelineContext() + # we won't actually buffer the chunk, but in case we need to pass + # the remaining content in the buffer when the stream ends, we need + # to store the parameters like model, timestamp, etc. + self._buffered_chunk = None + + def _buffer_chunk(self, chunk: ModelResponse) -> None: + """ + Add chunk content to buffer. + """ + self._buffered_chunk = chunk + for choice in chunk.choices: + # the last choice has no delta or content, let's not buffer it + if choice.delta is not None and choice.delta.content is not None: + self._context.buffer.append(choice.delta.content) + + async def process_stream( + self, stream: AsyncIterator[ModelResponse] + ) -> AsyncIterator[ModelResponse]: + """ + Process a stream through all pipeline steps + """ + try: + async for chunk in stream: + # Store chunk content in buffer + self._buffer_chunk(chunk) + + # Process chunk through each step of the pipeline + current_chunk = chunk + for step in self._pipeline_steps: + if current_chunk is None: + # Stop processing if a step returned None previously + # this means that the pipeline step requested to pause the stream + # instead, let's try again with the next chunk + break + + processed_chunk = await step.process_chunk( + current_chunk, self._context, self._input_context + ) + # the returned chunk becomes the input for the next chunk in the pipeline + current_chunk = processed_chunk + + # we have either gone through all the steps in the pipeline and have a chunk + # to return or we are paused in which case we don't yield + if current_chunk is not None: + # Step processed successfully, yield the chunk and clear buffer + self._context.buffer.clear() + yield current_chunk + # else: keep buffering for next iteration + + except Exception as e: + # Log exception and stop processing + raise e + finally: + # Process any remaining content in buffer when stream ends + if self._context.buffer: + final_content = "".join(self._context.buffer) + yield ModelResponse( + id=self._buffered_chunk.id, + choices=[ + StreamingChoices( + finish_reason=None, + # we just put one choice in the buffer, so 0 is fine + index=0, + delta=Delta(content=final_content, role="assistant"), + # umm..is this correct? + logprobs=self._buffered_chunk.choices[0].logprobs, + ) + ], + created=self._buffered_chunk.created, + model=self._buffered_chunk.model, + object="chat.completion.chunk", + ) + self._context.buffer.clear() + + +class OutputPipelineProcessor: + """ + Since we want to provide each run of the pipeline with a fresh context, + we need a factory to create new instances of the pipeline. + """ + + def __init__(self, pipeline_steps: list[OutputPipelineStep]): + self.pipeline_steps = pipeline_steps + + def _create_instance(self) -> OutputPipelineInstance: + """Create a new pipeline instance for processing a stream""" + return OutputPipelineInstance(self.pipeline_steps) + + async def process_stream( + self, stream: AsyncIterator[ModelResponse] + ) -> AsyncIterator[ModelResponse]: + """Create a new pipeline instance and process the stream""" + instance = self._create_instance() + async for chunk in instance.process_stream(stream): + yield chunk diff --git a/tests/pipeline/test_output.py b/tests/pipeline/test_output.py new file mode 100644 index 00000000..eeb42085 --- /dev/null +++ b/tests/pipeline/test_output.py @@ -0,0 +1,289 @@ +from typing import Optional + +import pytest +from litellm import ModelResponse +from litellm.types.utils import Delta, StreamingChoices + +from codegate.pipeline.base import PipelineContext +from codegate.pipeline.output import ( + OutputPipelineContext, + OutputPipelineInstance, + OutputPipelineStep, +) + + +class MockOutputPipelineStep(OutputPipelineStep): + """Mock pipeline step for testing""" + + def __init__(self, name: str, should_pause: bool = False, modify_content: bool = False): + self._name = name + self._should_pause = should_pause + self._modify_content = modify_content + + @property + def name(self) -> str: + return self._name + + async def process_chunk( + self, + chunk: ModelResponse, + context: OutputPipelineContext, + input_context: PipelineContext = None, + ) -> ModelResponse: + if self._should_pause: + return None + + if self._modify_content and chunk.choices[0].delta.content: + # Append step name to content to track modifications + modified_content = f"{chunk.choices[0].delta.content}_{self.name}" + chunk.choices[0].delta.content = modified_content + + return chunk + + +def create_model_response(content: str, id: str = "test") -> ModelResponse: + """Helper to create test ModelResponse objects""" + return ModelResponse( + id=id, + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content=content, role="assistant"), + logprobs=None, + ) + ], + created=0, + model="test-model", + object="chat.completion.chunk", + ) + + +class TestOutputPipelineContext: + def test_buffer_initialization(self): + """Test that buffer is properly initialized""" + context = OutputPipelineContext() + assert isinstance(context.buffer, list) + assert len(context.buffer) == 0 + + def test_buffer_operations(self): + """Test adding and clearing buffer content""" + context = OutputPipelineContext() + context.buffer.append("test1") + context.buffer.append("test2") + + assert len(context.buffer) == 2 + assert context.buffer == ["test1", "test2"] + + context.buffer.clear() + assert len(context.buffer) == 0 + + +class TestOutputPipelineInstance: + @pytest.mark.asyncio + async def test_single_step_processing(self): + """Test processing a stream through a single step""" + step = MockOutputPipelineStep("test_step", modify_content=True) + instance = OutputPipelineInstance([step]) + + async def mock_stream(): + yield create_model_response("Hello") + yield create_model_response("World") + + chunks = [] + async for chunk in instance.process_stream(mock_stream()): + chunks.append(chunk) + + assert len(chunks) == 2 + assert chunks[0].choices[0].delta.content == "Hello_test_step" + assert chunks[1].choices[0].delta.content == "World_test_step" + # Buffer should be cleared after each successful chunk + assert len(instance._context.buffer) == 0 + + @pytest.mark.asyncio + async def test_multiple_steps_processing(self): + """Test processing a stream through multiple steps""" + steps = [ + MockOutputPipelineStep("step1", modify_content=True), + MockOutputPipelineStep("step2", modify_content=True), + ] + instance = OutputPipelineInstance(steps) + + async def mock_stream(): + yield create_model_response("Hello") + + chunks = [] + async for chunk in instance.process_stream(mock_stream()): + chunks.append(chunk) + + assert len(chunks) == 1 + # Content should be modified by both steps + assert chunks[0].choices[0].delta.content == "Hello_step1_step2" + # Buffer should be cleared after successful processing + assert len(instance._context.buffer) == 0 + + @pytest.mark.asyncio + async def test_step_pausing(self): + """Test that a step can pause the stream and content is buffered until flushed""" + steps = [ + MockOutputPipelineStep("step1", should_pause=True), + MockOutputPipelineStep("step2", modify_content=True), + ] + instance = OutputPipelineInstance(steps) + + async def mock_stream(): + yield create_model_response("he") + yield create_model_response("ll") + yield create_model_response("o") + yield create_model_response(" wo") + yield create_model_response("rld") + + chunks = [] + async for chunk in instance.process_stream(mock_stream()): + chunks.append(chunk) + + # Should get one chunk at the end with all buffered content + assert len(chunks) == 1 + # Content should be buffered and combined + assert chunks[0].choices[0].delta.content == "hello world" + # Buffer should be cleared after flush + assert len(instance._context.buffer) == 0 + + @pytest.mark.asyncio + async def test_step_pausing_with_replacement(self): + """Test that a step can pause the stream and modify the buffered content before flushing""" + + class ReplacementStep(OutputPipelineStep): + """Step that replaces 'world' with 'moon' when found in buffer""" + + def __init__(self, should_pause: bool = True): + self._should_pause = should_pause + + @property + def name(self) -> str: + return "replacement" + + async def process_chunk( + self, + chunk: ModelResponse, + context: OutputPipelineContext, + input_context: PipelineContext = None, + ) -> Optional[ModelResponse]: + # Replace 'world' with 'moon' in buffered content + content = "".join(context.buffer) + if "world" in content: + content = content.replace("world", "moon") + chunk.choices = [ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content=content, role="assistant"), + logprobs=None, + ) + ] + return chunk + return None + + instance = OutputPipelineInstance([ReplacementStep()]) + + async def mock_stream(): + yield create_model_response("he") + yield create_model_response("ll") + yield create_model_response("o") + yield create_model_response("wo") + yield create_model_response("rld") + + chunks = [] + async for chunk in instance.process_stream(mock_stream()): + chunks.append(chunk) + + # Should get one chunk at the end with modified content + assert len(chunks) == 1 + assert chunks[0].choices[0].delta.content == "hellomoon" + # Buffer should be cleared after flush + assert len(instance._context.buffer) == 0 + + @pytest.mark.asyncio + async def test_buffer_processing(self): + """Test that content is properly buffered and cleared""" + step = MockOutputPipelineStep("test_step") + instance = OutputPipelineInstance([step]) + + async def mock_stream(): + yield create_model_response("Hello") + yield create_model_response("World") + + chunks = [] + async for chunk in instance.process_stream(mock_stream()): + chunks.append(chunk) + # Buffer should be cleared after each successful chunk + assert len(instance._context.buffer) == 0 + + assert len(chunks) == 2 + assert chunks[0].choices[0].delta.content == "Hello" + assert chunks[1].choices[0].delta.content == "World" + + @pytest.mark.asyncio + async def test_empty_stream(self): + """Test handling of an empty stream""" + step = MockOutputPipelineStep("test_step") + instance = OutputPipelineInstance([step]) + + async def mock_stream(): + if False: + yield # Empty stream + + chunks = [] + async for chunk in instance.process_stream(mock_stream()): + chunks.append(chunk) + + assert len(chunks) == 0 + assert len(instance._context.buffer) == 0 + + @pytest.mark.asyncio + async def test_input_context_passing(self): + """Test that input context is properly passed to steps""" + input_context = PipelineContext() + input_context.metadata["test"] = "value" + + class ContextCheckingStep(OutputPipelineStep): + @property + def name(self) -> str: + return "context_checker" + + async def process_chunk( + self, + chunk: ModelResponse, + context: OutputPipelineContext, + input_context: PipelineContext = None, + ) -> ModelResponse: + assert input_context.metadata["test"] == "value" + return chunk + + instance = OutputPipelineInstance([ContextCheckingStep()], input_context=input_context) + + async def mock_stream(): + yield create_model_response("test") + + async for _ in instance.process_stream(mock_stream()): + pass + + @pytest.mark.asyncio + async def test_buffer_flush_on_stream_end(self): + """Test that buffer is properly flushed when stream ends""" + step = MockOutputPipelineStep("test_step", should_pause=True) + instance = OutputPipelineInstance([step]) + + async def mock_stream(): + yield create_model_response("Hello") + yield create_model_response("World") + + chunks = [] + async for chunk in instance.process_stream(mock_stream()): + chunks.append(chunk) + + # Should get one chunk with combined buffer content + assert len(chunks) == 1 + assert chunks[0].choices[0].delta.content == "HelloWorld" + # Buffer should be cleared after flush + assert len(instance._context.buffer) == 0 From f106c9ef472242e84bca0fdfa7ab78133bf41976 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 2 Dec 2024 17:26:47 +0100 Subject: [PATCH 2/6] Split out parts of the CodegateSecrets pipeline step into a secret manager We'll want to decrypt secrets on the way from the LLM. For that we need to reuse parts of functionality that were so far in the secret encryption step. This commit splits them into a secrets manager. --- src/codegate/pipeline/secrets/manager.py | 112 +++++++++++++ src/codegate/pipeline/secrets/secrets.py | 192 +++++------------------ tests/pipeline/secrets/test_manager.py | 148 +++++++++++++++++ 3 files changed, 299 insertions(+), 153 deletions(-) create mode 100644 src/codegate/pipeline/secrets/manager.py create mode 100644 tests/pipeline/secrets/test_manager.py diff --git a/src/codegate/pipeline/secrets/manager.py b/src/codegate/pipeline/secrets/manager.py new file mode 100644 index 00000000..a7b32319 --- /dev/null +++ b/src/codegate/pipeline/secrets/manager.py @@ -0,0 +1,112 @@ +from typing import NamedTuple, Optional + +import structlog + +from codegate.pipeline.secrets.gatecrypto import CodeGateCrypto + +logger = structlog.get_logger("codegate") + + +class SecretEntry(NamedTuple): + """Represents a stored secret""" + + original: str + encrypted: str + service: str + secret_type: str + + +class SecretsManager: + """Manages encryption, storage and retrieval of secrets""" + + def __init__(self): + self.crypto = CodeGateCrypto() + self._session_store: dict[str, SecretEntry] = {} + self._encrypted_to_session: dict[str, str] = {} # Reverse lookup index + + def store_secret(self, value: str, service: str, secret_type: str, session_id: str) -> str: + """ + Encrypts and stores a secret value. + Returns the encrypted value. + """ + if not value: + raise ValueError("Value must be provided") + if not service: + raise ValueError("Service must be provided") + if not secret_type: + raise ValueError("Secret type must be provided") + if not session_id: + raise ValueError("Session ID must be provided") + + encrypted_value = self.crypto.encrypt_token(value, session_id) + + # Store mappings + self._session_store[session_id] = SecretEntry( + original=value, + encrypted=encrypted_value, + service=service, + secret_type=secret_type, + ) + self._encrypted_to_session[encrypted_value] = session_id + + logger.debug("Stored secret", service=service, type=secret_type, encrypted=encrypted_value) + + return encrypted_value + + def get_original_value(self, encrypted_value: str, session_id: str) -> Optional[str]: + """Retrieve original value for an encrypted value""" + try: + stored_session_id = self._encrypted_to_session.get(encrypted_value) + if stored_session_id == session_id: + return self._session_store[session_id].original + except Exception as e: + logger.error("Error retrieving secret", error=str(e)) + return None + + def get_by_session_id(self, session_id: str) -> Optional[SecretEntry]: + """Get stored data by session ID""" + return self._session_store.get(session_id) + + def cleanup(self): + """Securely wipe sensitive data""" + try: + # Convert and wipe original values + for entry in self._session_store.values(): + original_bytes = bytearray(entry.original.encode()) + self.crypto.wipe_bytearray(original_bytes) + + # Clear the dictionaries + self._session_store.clear() + self._encrypted_to_session.clear() + + logger.info("Secrets manager data securely wiped") + except Exception as e: + logger.error("Error during secure cleanup", error=str(e)) + + def cleanup_session(self, session_id: str): + """ + Remove a specific session's secrets and perform secure cleanup. + + Args: + session_id (str): The session identifier to remove + """ + try: + # Get the secret entry for the session + entry = self._session_store.get(session_id) + + if entry: + # Securely wipe the original value + original_bytes = bytearray(entry.original.encode()) + self.crypto.wipe_bytearray(original_bytes) + + # Remove the encrypted value from the reverse lookup index + self._encrypted_to_session.pop(entry.encrypted, None) + + # Remove the session from the store + self._session_store.pop(session_id, None) + + logger.debug("Session secrets securely removed", session_id=session_id) + else: + logger.debug("No secrets found for session", session_id=session_id) + except Exception as e: + logger.error("Error during session cleanup", session_id=session_id, error=str(e)) diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index f9c3d0df..8e6ee181 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -1,5 +1,3 @@ -import re - import structlog from litellm import ChatCompletionRequest @@ -8,7 +6,7 @@ PipelineResult, PipelineStep, ) -from codegate.pipeline.secrets.gatecrypto import CodeGateCrypto +from codegate.pipeline.secrets.manager import SecretsManager from codegate.pipeline.secrets.signatures import CodegateSignatures logger = structlog.get_logger("codegate") @@ -20,9 +18,6 @@ class CodegateSecrets(PipelineStep): def __init__(self): """Initialize the CodegateSecrets pipeline step.""" super().__init__() - self.crypto = CodeGateCrypto() - self._session_store = {} - self._encrypted_to_session = {} # Reverse lookup index @property def name(self) -> str: @@ -73,7 +68,7 @@ def _extend_match_boundaries(self, text: str, start: int, end: int) -> tuple[int return start, end - def _redeact_text(self, text: str) -> str: + def _redeact_text(self, text: str, secrets_manager: SecretsManager, session_id: str) -> str: """ Find and encrypt secrets in the given text. @@ -114,34 +109,19 @@ def _redeact_text(self, text: str) -> str: # Replace each match with its encrypted value for start, end, match in absolute_matches: - # Generate session key and encrypt the value - session_id = self.crypto.generate_session_key(None).hex() - encrypted_value = self.crypto.encrypt_token(match.value, session_id) - - print("Original value: ", match.value) - print("Encrypted value: ", encrypted_value) - print("Service: ", match.service) - print("Type: ", match.type) - - # Store the mapping - self._session_store[session_id] = { - "original": match.value, - "encrypted": encrypted_value, - "service": match.service, - "type": match.type, - } - # Store reverse lookup - self._encrypted_to_session[encrypted_value] = session_id - - # Print the session store - logger.info(f"Session store: {self._session_store}") + # Encrypt and store the value + encrypted_value = secrets_manager.store_secret( + match.value, + match.service, + match.type, + session_id, + ) # Create the replacement string replacement = f"REDACTED<${encrypted_value}>" # Replace the secret in the text protected_text[start:end] = replacement - # Store for logging found_secrets.append( { @@ -152,110 +132,20 @@ def _redeact_text(self, text: str) -> str: } ) - # Convert back to string - protected_string = "".join(protected_text) - - # Log the findings - logger.info("\nFound secrets:") - for secret in found_secrets: - logger.info(f"\nService: {secret['service']}") - logger.info(f"Type: {secret['type']}") - logger.info(f"Original: {secret['original']}") - logger.info(f"Encrypted: REDACTED<${secret['encrypted']}>") - - (f"\nProtected text:\n{protected_string}") - return protected_string + # Convert back to string + protected_string = "".join(protected_text) - def _get_original_value(self, encrypted_value: str) -> str: - """ - Get the original value for an encrypted value from the session store. + # Log the findings + logger.info("\nFound secrets:") - Args: - encrypted_value: The encrypted value to look up + for secret in found_secrets: + logger.info(f"\nService: {secret['service']}") + logger.info(f"Type: {secret['type']}") + logger.info(f"Original: {secret['original']}") + logger.info(f"Encrypted: REDACTED<${secret['encrypted']}>") - Returns: - Original value if found, or the encrypted value if not found - """ - try: - # Use reverse lookup index to get session_id - session_id = self._encrypted_to_session.get(encrypted_value) - if session_id: - return self._session_store[session_id]["original"] - except Exception as e: - logger.error(f"Error looking up original value: {e}") - return encrypted_value - - def get_by_session_id(self, session_id: str) -> dict | None: - """ - Get stored data directly by session ID. - - Args: - session_id: The session ID to look up - - Returns: - Dict containing the stored data if found, None otherwise - """ - try: - return self._session_store.get(session_id) - except Exception as e: - logger.error(f"Error looking up by session ID: {e}") - return None - - def _cleanup_session_store(self): - """ - Securely wipe sensitive data from session stores. - """ - try: - # Convert and wipe original values - for session_data in self._session_store.values(): - if "original" in session_data: - original_bytes = bytearray(session_data["original"].encode()) - self.crypto.wipe_bytearray(original_bytes) - - # Clear the dictionaries - self._session_store.clear() - self._encrypted_to_session.clear() - - logger.info("Session stores securely wiped") - except Exception as e: - logger.error(f"Error during secure cleanup: {e}") - - def _unredact_text(self, protected_text: str) -> str: - """ - Decrypt and restore the original text from protected text. - - Args: - protected_text: The protected text containing encrypted values - - Returns: - Original text with decrypted values - """ - # Find all REDACTED markers - pattern = r"REDACTED<\$([^>]+)>" - - # Start from the beginning of the text - result = [] - last_end = 0 - - # Find each REDACTED section and replace with original value - for match in re.finditer(pattern, protected_text): - # Add text before this match - result.append(protected_text[last_end : match.start()]) - - # Get and add the original value - encrypted_value = match.group(1) - original_value = self._get_original_value(encrypted_value) - result.append(original_value) - - last_end = match.end() - - # Add any remaining text - result.append(protected_text[last_end:]) - - # Join all parts together - unprotected_text = "".join(result) - logger.info(f"\nUnprotected text:\n{unprotected_text}") - return unprotected_text + print(f"\nProtected text:\n{protected_string}") + return "".join(protected_text) async def process( self, request: ChatCompletionRequest, context: PipelineContext @@ -270,32 +160,28 @@ async def process( Returns: PipelineResult containing the processed request """ + secrets_manager = context.sensitive.manager + if not secrets_manager or not isinstance(secrets_manager, SecretsManager): + # Should this be an error? + raise ValueError("Secrets manager not found in context") + session_id = context.sensitive.session_id + if not session_id: + raise ValueError("Session ID not found in context") + last_user_message = self.get_last_user_message(request) - extracted_string = last_user_message[0] if last_user_message else None - print(f"Original text:\n{extracted_string}") + extracted_string = None + extracted_index = None + if last_user_message: + extracted_string = last_user_message[0] + extracted_index = last_user_message[1] if not extracted_string: return PipelineResult(request=request) - try: - # Protect the text - protected_string = self._redeact_text(extracted_string) - print(f"\nProtected text:\n{protected_string}") - - # LLM - unprotected_string = self._unredact_text(protected_string) - print(f"\nUnprotected text:\n{unprotected_string}") - - # Update the user message with protected text - if isinstance(request["messages"], list): - for msg in request["messages"]: - if msg.get("role") == "user" and msg.get("content") == extracted_string: - msg["content"] = protected_string - - return PipelineResult(request=request) - except Exception as e: - logger.error(f"CodegateSecrets operation failed: {e}") + # Protect the text + protected_string = self._redeact_text(extracted_string, secrets_manager, session_id) - finally: - # Clean up sensitive data - self._cleanup_session_store() + # Update the user message + new_request = request.copy() + new_request["messages"][extracted_index]["content"] = protected_string + return PipelineResult(request=new_request) diff --git a/tests/pipeline/secrets/test_manager.py b/tests/pipeline/secrets/test_manager.py new file mode 100644 index 00000000..5cb06ade --- /dev/null +++ b/tests/pipeline/secrets/test_manager.py @@ -0,0 +1,148 @@ +import pytest + +from codegate.pipeline.secrets.manager import SecretEntry, SecretsManager + + +class TestSecretsManager: + def setup_method(self): + """Setup a fresh SecretsManager for each test""" + self.manager = SecretsManager() + self.test_session = "test_session_id" + self.test_value = "super_secret_value" + self.test_service = "test_service" + self.test_type = "api_key" + + def test_store_secret(self): + """Test basic secret storage and retrieval""" + # Store a secret + encrypted = self.manager.store_secret( + self.test_value, self.test_service, self.test_type, self.test_session + ) + + # Verify the secret was stored + stored = self.manager.get_by_session_id(self.test_session) + assert isinstance(stored, SecretEntry) + assert stored.original == self.test_value + assert stored.encrypted == encrypted + assert stored.service == self.test_service + assert stored.secret_type == self.test_type + + # Verify encrypted value can be retrieved + retrieved = self.manager.get_original_value(encrypted, self.test_session) + assert retrieved == self.test_value + + def test_get_original_value_wrong_session(self): + """Test that secrets can't be accessed with wrong session ID""" + encrypted = self.manager.store_secret( + self.test_value, self.test_service, self.test_type, self.test_session + ) + + # Try to retrieve with wrong session ID + wrong_session = "wrong_session_id" + retrieved = self.manager.get_original_value(encrypted, wrong_session) + assert retrieved is None + + def test_get_original_value_nonexistent(self): + """Test handling of non-existent encrypted values""" + retrieved = self.manager.get_original_value("nonexistent", self.test_session) + assert retrieved is None + + def test_cleanup_session(self): + """Test that session cleanup properly removes secrets""" + # Store multiple secrets in different sessions + session1 = "session1" + session2 = "session2" + + encrypted1 = self.manager.store_secret("secret1", "service1", "type1", session1) + encrypted2 = self.manager.store_secret("secret2", "service2", "type2", session2) + + # Clean up session1 + self.manager.cleanup_session(session1) + + # Verify session1 secrets are gone + assert self.manager.get_by_session_id(session1) is None + assert self.manager.get_original_value(encrypted1, session1) is None + + # Verify session2 secrets remain + assert self.manager.get_by_session_id(session2) is not None + assert self.manager.get_original_value(encrypted2, session2) == "secret2" + + def test_cleanup(self): + """Test that cleanup properly wipes all data""" + # Store multiple secrets + self.manager.store_secret("secret1", "service1", "type1", "session1") + self.manager.store_secret("secret2", "service2", "type2", "session2") + + # Perform cleanup + self.manager.cleanup() + + # Verify all data is wiped + assert len(self.manager._session_store) == 0 + assert len(self.manager._encrypted_to_session) == 0 + + def test_multiple_secrets_same_session(self): + """Test storing multiple secrets in the same session""" + # Store multiple secrets in same session + encrypted1 = self.manager.store_secret("secret1", "service1", "type1", self.test_session) + encrypted2 = self.manager.store_secret("secret2", "service2", "type2", self.test_session) + + # Latest secret should be retrievable + stored = self.manager.get_by_session_id(self.test_session) + assert stored.original == "secret2" + assert stored.encrypted == encrypted2 + + # Both encrypted values should map to the session + assert self.manager._encrypted_to_session[encrypted1] == self.test_session + assert self.manager._encrypted_to_session[encrypted2] == self.test_session + + def test_error_handling(self): + """Test error handling in secret operations""" + # Test with None values + with pytest.raises(ValueError): + self.manager.store_secret(None, self.test_service, self.test_type, self.test_session) + + with pytest.raises(ValueError): + self.manager.store_secret(self.test_value, None, self.test_type, self.test_session) + + with pytest.raises(ValueError): + self.manager.store_secret(self.test_value, self.test_service, None, self.test_session) + + with pytest.raises(ValueError): + self.manager.store_secret(self.test_value, self.test_service, self.test_type, None) + + def test_secure_cleanup(self): + """Test that cleanup securely wipes sensitive data""" + # Store a secret + self.manager.store_secret( + self.test_value, self.test_service, self.test_type, self.test_session + ) + + # Get reference to stored data before cleanup + stored = self.manager.get_by_session_id(self.test_session) + original_value = stored.original + + # Perform cleanup + self.manager.cleanup() + + # Verify the original string was overwritten, not just removed + # This test is a bit tricky since Python strings are immutable, + # but we can at least verify the data is no longer accessible + assert original_value not in str(self.manager._session_store) + assert self.test_value not in str(self.manager._session_store) + + def test_session_isolation(self): + """Test that sessions are properly isolated""" + session1 = "session1" + session2 = "session2" + + # Store secrets in different sessions + encrypted1 = self.manager.store_secret("secret1", "service1", "type1", session1) + encrypted2 = self.manager.store_secret("secret2", "service2", "type2", session2) + + # Verify cross-session access is not possible + assert self.manager.get_original_value(encrypted1, session2) is None + assert self.manager.get_original_value(encrypted2, session1) is None + + # Verify correct session access works + assert self.manager.get_original_value(encrypted1, session1) == "secret1" + assert self.manager.get_original_value(encrypted2, session2) == "secret2" From fc683d33934a2dfd3ed124ac0504eb6cb3366a9f Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 2 Dec 2024 17:29:55 +0100 Subject: [PATCH 3/6] Implement secret unredaction step --- src/codegate/pipeline/secrets/secrets.py | 95 +++++++++++++++++++++++- 1 file changed, 94 insertions(+), 1 deletion(-) diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index 8e6ee181..265d6b86 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -1,11 +1,16 @@ +import re +from typing import Optional + import structlog -from litellm import ChatCompletionRequest +from litellm import ChatCompletionRequest, ModelResponse +from litellm.types.utils import Delta, StreamingChoices from codegate.pipeline.base import ( PipelineContext, PipelineResult, PipelineStep, ) +from codegate.pipeline.output import OutputPipelineContext, OutputPipelineStep from codegate.pipeline.secrets.manager import SecretsManager from codegate.pipeline.secrets.signatures import CodegateSignatures @@ -185,3 +190,91 @@ async def process( new_request = request.copy() new_request["messages"][extracted_index]["content"] = protected_string return PipelineResult(request=new_request) + + +class SecretUnredactionStep(OutputPipelineStep): + """Pipeline step that unredacts protected content in the stream""" + + def __init__(self): + self.redacted_pattern = re.compile(r"REDACTED<\$([^>]+)>") + self.marker_start = "REDACTED<$" + self.marker_end = ">" + + @property + def name(self) -> str: + return "secret-unredaction" + + def _is_partial_marker_prefix(self, text: str) -> bool: + """Check if text ends with a partial marker prefix""" + for i in range(1, len(self.marker_start) + 1): + if text.endswith(self.marker_start[:i]): + return True + return False + + def _find_complete_redaction(self, text: str) -> tuple[Optional[re.Match[str]], str]: + """ + Find the first complete REDACTED marker in text. + Returns (match, remaining_text) if found, (None, original_text) if not. + """ + matches = list(self.redacted_pattern.finditer(text)) + if not matches: + return None, text + + # Get the first complete match + match = matches[0] + return match, text[match.end() :] + + async def process_chunk( + self, + chunk: ModelResponse, + context: OutputPipelineContext, + input_context: Optional[PipelineContext] = None, + ) -> Optional[ModelResponse]: + """Process a single chunk of the stream""" + if input_context.sensitive is None or input_context.sensitive.manager is None: + raise ValueError("Secrets manager not found in input context") + if input_context.sensitive.session_id == "": + raise ValueError("Session ID not found in input context") + + if not chunk.choices[0].delta.content: + return chunk + + # Check the buffered content + buffered_content = "".join(context.buffer) + + # Look for complete REDACTED markers first + match, remaining = self._find_complete_redaction(buffered_content) + if match: + # Found a complete marker, process it + encrypted_value = match.group(1) + original_value = input_context.sensitive.manager.get_original_value( + encrypted_value, + input_context.sensitive.session_id, + ) + + if original_value is None: + # If value not found, leave as is + original_value = match.group(0) # Keep the REDACTED marker + + # Return the unredacted content up to this point + chunk.choices = [ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta( + content=buffered_content[: match.start()] + original_value + remaining, + role="assistant", + ), + logprobs=None, + ) + ] + return chunk + + # If we have a partial marker at the end, keep buffering + if self.marker_start in buffered_content or self._is_partial_marker_prefix( + buffered_content + ): + return None + + # No markers or partial markers, let pipeline handle the chunk normally + return chunk From b5ff345a50298257c227ab5c31f227ace4d12287 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 2 Dec 2024 17:33:18 +0100 Subject: [PATCH 4/6] Plug in the secrets manager and output pipeline --- src/codegate/pipeline/base.py | 15 ++++++ src/codegate/pipeline/secrets/secrets.py | 2 + src/codegate/providers/anthropic/provider.py | 6 +++ src/codegate/providers/base.py | 27 +++++++--- src/codegate/providers/llamacpp/provider.py | 6 +++ src/codegate/providers/ollama/provider.py | 6 +++ src/codegate/providers/openai/provider.py | 6 +++ src/codegate/providers/vllm/provider.py | 6 +++ src/codegate/server.py | 53 ++++++++++++++++--- .../providers/ollama/test_ollama_provider.py | 2 +- 10 files changed, 116 insertions(+), 13 deletions(-) diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index dc7bac53..7ce07b1b 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -1,9 +1,12 @@ +import uuid from abc import ABC, abstractmethod from dataclasses import dataclass, field from typing import Any, Dict, List, Optional from litellm import ChatCompletionRequest +from codegate.pipeline.secrets.manager import SecretsManager + @dataclass class CodeSnippet: @@ -24,10 +27,17 @@ def __post_init__(self): self.language = self.language.strip().lower() +@dataclass +class PipelineSensitiveData: + manager: SecretsManager + session_id: str + + @dataclass class PipelineContext: code_snippets: List[CodeSnippet] = field(default_factory=list) metadata: Dict[str, Any] = field(default_factory=dict) + sensitive: Optional[PipelineSensitiveData] = field(default_factory=lambda: None) def add_code_snippet(self, snippet: CodeSnippet): self.code_snippets.append(snippet) @@ -139,6 +149,7 @@ def __init__(self, pipeline_steps: List[PipelineStep]): async def process_request( self, + secret_manager: SecretsManager, request: ChatCompletionRequest, ) -> PipelineResult: """ @@ -146,11 +157,15 @@ async def process_request( Args: request: The chat completion request to process + secret_manager: The secrets manager instance to gather sensitive data from the request Returns: PipelineResult containing either a modified request or response structure """ context = PipelineContext() + context.sensitive = PipelineSensitiveData( + manager=secret_manager, session_id=str(uuid.uuid4()) + ) # Generate a new session ID for each request current_request = request for step in self.pipeline_steps: diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index 265d6b86..782211a6 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -231,6 +231,8 @@ async def process_chunk( input_context: Optional[PipelineContext] = None, ) -> Optional[ModelResponse]: """Process a single chunk of the stream""" + if not input_context: + raise ValueError("Input context not found") if input_context.sensitive is None or input_context.sensitive.manager is None: raise ValueError("Secrets manager not found in input context") if input_context.sensitive.session_id == "": diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py index 4d7eba59..32909260 100644 --- a/src/codegate/providers/anthropic/provider.py +++ b/src/codegate/providers/anthropic/provider.py @@ -3,6 +3,8 @@ from fastapi import Header, HTTPException, Request +from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.secrets.manager import SecretsManager from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer from codegate.providers.anthropic.completion_handler import AnthropicCompletion from codegate.providers.base import BaseProvider, SequentialPipelineProcessor @@ -12,16 +14,20 @@ class AnthropicProvider(BaseProvider): def __init__( self, + secrets_manager: SecretsManager, pipeline_processor: Optional[SequentialPipelineProcessor] = None, fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + output_pipeline_processor: Optional[OutputPipelineProcessor] = None, ): completion_handler = AnthropicCompletion(stream_generator=anthropic_stream_generator) super().__init__( + secrets_manager, AnthropicInputNormalizer(), AnthropicOutputNormalizer(), completion_handler, pipeline_processor, fim_pipeline_processor, + output_pipeline_processor, ) @property diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 585c1d9a..51c6a353 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -7,7 +7,9 @@ from litellm.types.llms.openai import ChatCompletionRequest from codegate.db.connection import DbRecorder -from codegate.pipeline.base import PipelineResult, SequentialPipelineProcessor +from codegate.pipeline.base import PipelineContext, PipelineResult, SequentialPipelineProcessor +from codegate.pipeline.output import OutputPipelineInstance, OutputPipelineProcessor +from codegate.pipeline.secrets.manager import SecretsManager from codegate.providers.completion.base import BaseCompletionHandler from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer @@ -25,18 +27,22 @@ class BaseProvider(ABC): def __init__( self, + secrets_manager: Optional[SecretsManager], input_normalizer: ModelInputNormalizer, output_normalizer: ModelOutputNormalizer, completion_handler: BaseCompletionHandler, pipeline_processor: Optional[SequentialPipelineProcessor] = None, fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + output_pipeline_processor: Optional[OutputPipelineProcessor] = None, ): self.router = APIRouter() + self._secrets_manager = secrets_manager self._completion_handler = completion_handler self._input_normalizer = input_normalizer self._output_normalizer = output_normalizer self._pipeline_processor = pipeline_processor self._fim_pipelin_processor = fim_pipeline_processor + self._output_pipeline_processor = output_pipeline_processor self._pipeline_response_formatter = PipelineResponseFormatter(output_normalizer) self.db_recorder = DbRecorder() @@ -53,16 +59,20 @@ def provider_route_name(self) -> str: async def _run_output_stream_pipeline( self, + input_context: PipelineContext, normalized_stream: AsyncIterator[ModelResponse], ) -> AsyncIterator[ModelResponse]: - # we don't have a pipeline for output stream yet - return normalized_stream + output_pipeline_instance = OutputPipelineInstance( + self._output_pipeline_processor.pipeline_steps, + input_context=input_context, + ) + return output_pipeline_instance.process_stream(normalized_stream) def _run_output_pipeline( self, normalized_response: ModelResponse, ) -> ModelResponse: - # we don't have a pipeline for output yet + # we don't have a pipeline for non-streamed output yet return normalized_response async def _run_input_pipeline( @@ -78,7 +88,9 @@ async def _run_input_pipeline( if pipeline_processor is None: return PipelineResult(request=normalized_request) - result = await pipeline_processor.process_request(normalized_request) + result = await pipeline_processor.process_request( + secret_manager=self._secrets_manager, request=normalized_request + ) # TODO(jakub): handle this by returning a message to the client if result.error_message: @@ -175,7 +187,10 @@ async def complete( return self._output_normalizer.denormalize(pipeline_output) normalized_stream = self._output_normalizer.normalize_streaming(model_response) - pipeline_output_stream = await self._run_output_stream_pipeline(normalized_stream) + pipeline_output_stream = await self._run_output_stream_pipeline( + input_pipeline_result.context, + normalized_stream, + ) return self._output_normalizer.denormalize_streaming(pipeline_output_stream) def get_routes(self) -> APIRouter: diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py index efe06f09..d97feb79 100644 --- a/src/codegate/providers/llamacpp/provider.py +++ b/src/codegate/providers/llamacpp/provider.py @@ -3,6 +3,8 @@ from fastapi import Request +from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.secrets.manager import SecretsManager from codegate.providers.base import BaseProvider, SequentialPipelineProcessor from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer @@ -11,16 +13,20 @@ class LlamaCppProvider(BaseProvider): def __init__( self, + secrets_manager: SecretsManager, pipeline_processor: Optional[SequentialPipelineProcessor] = None, fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + output_pipeline_processor: Optional[OutputPipelineProcessor] = None, ): completion_handler = LlamaCppCompletionHandler() super().__init__( + secrets_manager, LLamaCppInputNormalizer(), LLamaCppOutputNormalizer(), completion_handler, pipeline_processor, fim_pipeline_processor, + output_pipeline_processor, ) @property diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py index 5b8c9a4b..95c7fea8 100644 --- a/src/codegate/providers/ollama/provider.py +++ b/src/codegate/providers/ollama/provider.py @@ -4,6 +4,8 @@ from fastapi import Request from codegate.config import Config +from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.secrets.manager import SecretsManager from codegate.providers.base import BaseProvider, SequentialPipelineProcessor from codegate.providers.ollama.adapter import OllamaInputNormalizer, OllamaOutputNormalizer from codegate.providers.ollama.completion_handler import OllamaCompletionHandler @@ -12,16 +14,20 @@ class OllamaProvider(BaseProvider): def __init__( self, + secrets_manager: Optional[SecretsManager], pipeline_processor: Optional[SequentialPipelineProcessor] = None, fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + output_pipeline_processor: Optional[OutputPipelineProcessor] = None, ): completion_handler = OllamaCompletionHandler() super().__init__( + secrets_manager, OllamaInputNormalizer(), OllamaOutputNormalizer(), completion_handler, pipeline_processor, fim_pipeline_processor, + output_pipeline_processor, ) # Get the Ollama base URL config = Config.get_config() diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index 741d3143..649805a9 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -3,6 +3,8 @@ from fastapi import Header, HTTPException, Request +from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.secrets.manager import SecretsManager from codegate.providers.base import BaseProvider, SequentialPipelineProcessor from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer @@ -11,16 +13,20 @@ class OpenAIProvider(BaseProvider): def __init__( self, + secrets_manager: SecretsManager, pipeline_processor: Optional[SequentialPipelineProcessor] = None, fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + output_pipeline_processor: Optional[OutputPipelineProcessor] = None, ): completion_handler = LiteLLmShim(stream_generator=sse_stream_generator) super().__init__( + secrets_manager, OpenAIInputNormalizer(), OpenAIOutputNormalizer(), completion_handler, pipeline_processor, fim_pipeline_processor, + output_pipeline_processor, ) @property diff --git a/src/codegate/providers/vllm/provider.py b/src/codegate/providers/vllm/provider.py index a342ac6f..242ce05f 100644 --- a/src/codegate/providers/vllm/provider.py +++ b/src/codegate/providers/vllm/provider.py @@ -5,6 +5,8 @@ from litellm import atext_completion from codegate.config import Config +from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.secrets.manager import SecretsManager from codegate.providers.base import BaseProvider, SequentialPipelineProcessor from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.vllm.adapter import VLLMInputNormalizer, VLLMOutputNormalizer @@ -13,18 +15,22 @@ class VLLMProvider(BaseProvider): def __init__( self, + secrets_manager: SecretsManager, pipeline_processor: Optional[SequentialPipelineProcessor] = None, fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, + output_pipeline_processor: Optional[OutputPipelineProcessor] = None, ): completion_handler = LiteLLmShim( stream_generator=sse_stream_generator, fim_completion_func=atext_completion ) super().__init__( + secrets_manager, VLLMInputNormalizer(), VLLMOutputNormalizer(), completion_handler, pipeline_processor, fim_pipeline_processor, + output_pipeline_processor, ) @property diff --git a/src/codegate/server.py b/src/codegate/server.py index f8a953f4..a45e6b65 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -8,7 +8,9 @@ from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever from codegate.pipeline.codegate_system_prompt.codegate import CodegateSystemPrompt from codegate.pipeline.extract_snippets.extract_snippets import CodeSnippetExtractor -from codegate.pipeline.secrets.secrets import CodegateSecrets +from codegate.pipeline.output import OutputPipelineProcessor, OutputPipelineStep +from codegate.pipeline.secrets.manager import SecretsManager +from codegate.pipeline.secrets.secrets import CodegateSecrets, SecretUnredactionStep from codegate.pipeline.secrets.signatures import CodegateSignatures from codegate.pipeline.version.version import CodegateVersion from codegate.providers.anthropic.provider import AnthropicProvider @@ -27,6 +29,12 @@ def init_app() -> FastAPI: version=__version__, ) + # Initialize secrets manager + # TODO: we need to clean up the secrets manager + # after the conversation is concluded + # this was done in the pipeline step but I just removed it for now + secrets_manager = SecretsManager() + steps: List[PipelineStep] = [ CodegateVersion(), CodeSnippetExtractor(), @@ -39,6 +47,11 @@ def init_app() -> FastAPI: pipeline = SequentialPipelineProcessor(steps) fim_pipeline = SequentialPipelineProcessor(fim_steps) + output_steps: List[OutputPipelineStep] = [ + SecretUnredactionStep(), + ] + output_pipeline = OutputPipelineProcessor(output_steps) + # Create provider registry registry = ProviderRegistry(app) @@ -47,21 +60,49 @@ def init_app() -> FastAPI: # Register all known providers registry.add_provider( - "openai", OpenAIProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline) + "openai", + OpenAIProvider( + secrets_manager=secrets_manager, + pipeline_processor=pipeline, + fim_pipeline_processor=fim_pipeline, + output_pipeline_processor=output_pipeline, + ), ) registry.add_provider( "anthropic", - AnthropicProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline), + AnthropicProvider( + secrets_manager=secrets_manager, + pipeline_processor=pipeline, + fim_pipeline_processor=fim_pipeline, + output_pipeline_processor=output_pipeline, + ), ) registry.add_provider( "llamacpp", - LlamaCppProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline), + LlamaCppProvider( + secrets_manager=secrets_manager, + pipeline_processor=pipeline, + fim_pipeline_processor=fim_pipeline, + output_pipeline_processor=output_pipeline, + ), ) registry.add_provider( - "vllm", VLLMProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline) + "vllm", + VLLMProvider( + secrets_manager=secrets_manager, + pipeline_processor=pipeline, + fim_pipeline_processor=fim_pipeline, + output_pipeline_processor=output_pipeline, + ), ) registry.add_provider( - "ollama", OllamaProvider(pipeline_processor=pipeline, fim_pipeline_processor=fim_pipeline) + "ollama", + OllamaProvider( + secrets_manager=secrets_manager, + pipeline_processor=pipeline, + fim_pipeline_processor=fim_pipeline, + output_pipeline_processor=output_pipeline, + ), ) # Create and add system routes diff --git a/tests/providers/ollama/test_ollama_provider.py b/tests/providers/ollama/test_ollama_provider.py index 5fd5cf4e..ed10e7fd 100644 --- a/tests/providers/ollama/test_ollama_provider.py +++ b/tests/providers/ollama/test_ollama_provider.py @@ -18,7 +18,7 @@ def __init__(self): def app(): """Create FastAPI app with Ollama provider.""" app = FastAPI() - provider = OllamaProvider() + provider = OllamaProvider(None) app.include_router(provider.get_routes()) return app From 92910c06b8cb52b8ac2d4cd3f0da9a2610694d96 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 2 Dec 2024 19:31:16 +0100 Subject: [PATCH 5/6] Clean up the secure pipeline after the completion is concluded --- src/codegate/pipeline/base.py | 8 ++++++++ src/codegate/pipeline/output.py | 4 ++++ src/codegate/providers/base.py | 15 ++++++++++++++- 3 files changed, 26 insertions(+), 1 deletion(-) diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index 7ce07b1b..f5b62c8c 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -32,6 +32,14 @@ class PipelineSensitiveData: manager: SecretsManager session_id: str + def secure_cleanup(self): + """Securely cleanup sensitive data for this session""" + if self.manager is None or self.session_id == "": + return + + self.manager.cleanup_session(self.session_id) + self.session_id = "" + @dataclass class PipelineContext: diff --git a/src/codegate/pipeline/output.py b/src/codegate/pipeline/output.py index 09239644..b74ce2d0 100644 --- a/src/codegate/pipeline/output.py +++ b/src/codegate/pipeline/output.py @@ -146,6 +146,10 @@ async def process_stream( ) self._context.buffer.clear() + # Cleanup sensitive data through the input context + if self._input_context and self._input_context.sensitive: + self._input_context.sensitive.secure_cleanup() + class OutputPipelineProcessor: """ diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index 51c6a353..509f8e9c 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -147,6 +147,18 @@ def _is_fim_request(self, request: Request, data: Dict) -> bool: return self._is_fim_request_body(data) + async def _cleanup_after_streaming( + self, stream: AsyncIterator[ModelResponse], context: PipelineContext + ) -> AsyncIterator[ModelResponse]: + """Wraps the stream to ensure cleanup after consumption""" + try: + async for item in stream: + yield item + finally: + # Ensure sensitive data is cleaned up after the stream is consumed + if context and context.sensitive: + context.sensitive.secure_cleanup() + async def complete( self, data: Dict, api_key: Optional[str], is_fim_request: bool ) -> Union[ModelResponse, AsyncIterator[ModelResponse]]: @@ -191,7 +203,8 @@ async def complete( input_pipeline_result.context, normalized_stream, ) - return self._output_normalizer.denormalize_streaming(pipeline_output_stream) + denormalized_stream = self._output_normalizer.denormalize_streaming(pipeline_output_stream) + return self._cleanup_after_streaming(denormalized_stream, input_pipeline_result.context) def get_routes(self) -> APIRouter: return self.router From 5ada6b471286ba9a1b56dbdce0a60a8264155803 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Tue, 3 Dec 2024 00:20:17 +0100 Subject: [PATCH 6/6] Unit test the de-obfuscation of secrets --- tests/pipeline/secrets/test_secrets.py | 147 +++++++++++++++++++++++++ 1 file changed, 147 insertions(+) create mode 100644 tests/pipeline/secrets/test_secrets.py diff --git a/tests/pipeline/secrets/test_secrets.py b/tests/pipeline/secrets/test_secrets.py new file mode 100644 index 00000000..52be4eaf --- /dev/null +++ b/tests/pipeline/secrets/test_secrets.py @@ -0,0 +1,147 @@ +import pytest +from litellm import ModelResponse +from litellm.types.utils import Delta, StreamingChoices + +from codegate.pipeline.base import PipelineContext, PipelineSensitiveData +from codegate.pipeline.output import OutputPipelineContext +from codegate.pipeline.secrets.manager import SecretsManager +from codegate.pipeline.secrets.secrets import SecretUnredactionStep + + +def create_model_response(content: str) -> ModelResponse: + """Helper to create test ModelResponse objects""" + return ModelResponse( + id="test", + choices=[ + StreamingChoices( + finish_reason=None, + index=0, + delta=Delta(content=content, role="assistant"), + logprobs=None, + ) + ], + created=0, + model="test-model", + object="chat.completion.chunk", + ) + + +class TestSecretUnredactionStep: + def setup_method(self): + """Setup fresh instances for each test""" + self.step = SecretUnredactionStep() + self.context = OutputPipelineContext() + self.secrets_manager = SecretsManager() + self.session_id = "test_session" + + # Setup input context with secrets manager + self.input_context = PipelineContext() + self.input_context.sensitive = PipelineSensitiveData( + manager=self.secrets_manager, session_id=self.session_id + ) + + @pytest.mark.asyncio + async def test_complete_marker_processing(self): + """Test processing of a complete REDACTED marker""" + # Store a secret + encrypted = self.secrets_manager.store_secret( + "secret_value", "test_service", "api_key", self.session_id + ) + + # Add content with REDACTED marker to buffer + self.context.buffer.append(f"Here is the REDACTED<${encrypted}> in text") + + # Process a chunk + result = await self.step.process_chunk( + create_model_response("more text"), self.context, self.input_context + ) + + # Verify unredaction + assert result is not None + assert result.choices[0].delta.content == "Here is the secret_value in text" + + @pytest.mark.asyncio + async def test_partial_marker_buffering(self): + """Test handling of partial REDACTED markers""" + # Add partial marker to buffer + self.context.buffer.append("Here is REDACTED<$") + + # Process a chunk + result = await self.step.process_chunk( + create_model_response("partial"), self.context, self.input_context + ) + + # Should return None to continue buffering + assert result is None + + @pytest.mark.asyncio + async def test_invalid_encrypted_value(self): + """Test handling of invalid encrypted values""" + # Add content with invalid encrypted value + self.context.buffer.append("Here is REDACTED<$invalid_value> in text") + + # Process chunk + result = await self.step.process_chunk( + create_model_response("text"), self.context, self.input_context + ) + + # Should keep the REDACTED marker for invalid values + assert result is not None + assert result.choices[0].delta.content == "Here is REDACTED<$invalid_value> in text" + + @pytest.mark.asyncio + async def test_missing_context(self): + """Test handling of missing input context or secrets manager""" + # Test with None input context + with pytest.raises(ValueError, match="Input context not found"): + await self.step.process_chunk(create_model_response("text"), self.context, None) + + # Test with missing secrets manager + self.input_context.sensitive.manager = None + with pytest.raises(ValueError, match="Secrets manager not found in input context"): + await self.step.process_chunk( + create_model_response("text"), self.context, self.input_context + ) + + @pytest.mark.asyncio + async def test_empty_content(self): + """Test handling of empty content chunks""" + result = await self.step.process_chunk( + create_model_response(""), self.context, self.input_context + ) + + # Should pass through empty chunks + assert result is not None + assert result.choices[0].delta.content == "" + + @pytest.mark.asyncio + async def test_no_markers(self): + """Test processing of content without any REDACTED markers""" + # Create chunk with content + chunk = create_model_response("Regular text without any markers") + + # Process chunk + result = await self.step.process_chunk(chunk, self.context, self.input_context) + + # Should pass through unchanged + assert result is not None + assert result.choices[0].delta.content == "Regular text without any markers" + + @pytest.mark.asyncio + async def test_wrong_session(self): + """Test unredaction with wrong session ID""" + # Store secret with one session + encrypted = self.secrets_manager.store_secret( + "secret_value", "test_service", "api_key", "different_session" + ) + + # Try to unredact with different session + self.context.buffer.append(f"Here is the REDACTED<${encrypted}> in text") + + result = await self.step.process_chunk( + create_model_response("text"), self.context, self.input_context + ) + + # Should keep REDACTED marker when session doesn't match + assert result is not None + assert result.choices[0].delta.content == f"Here is the REDACTED<${encrypted}> in text"