Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/codegate/pipeline/base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import uuid
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

from litellm import ChatCompletionRequest

from codegate.pipeline.secrets.manager import SecretsManager


@dataclass
class CodeSnippet:
Expand All @@ -24,10 +27,25 @@ def __post_init__(self):
self.language = self.language.strip().lower()


@dataclass
class PipelineSensitiveData:
    """Sensitive state tied to one pipeline run: the secrets manager plus
    the session ID under which this run's secrets are stored."""

    manager: SecretsManager
    session_id: str

    def secure_cleanup(self):
        """Securely cleanup sensitive data for this session"""
        # Nothing to do unless we have both a manager and a live session
        if self.manager is not None and self.session_id != "":
            self.manager.cleanup_session(self.session_id)
            # Blank the session ID so a second cleanup is a no-op
            self.session_id = ""


@dataclass
class PipelineContext:
code_snippets: List[CodeSnippet] = field(default_factory=list)
metadata: Dict[str, Any] = field(default_factory=dict)
sensitive: Optional[PipelineSensitiveData] = field(default_factory=lambda: None)

    def add_code_snippet(self, snippet: CodeSnippet) -> None:
        """Record a code snippet extracted from the request in this context."""
        self.code_snippets.append(snippet)
Expand Down Expand Up @@ -139,18 +157,23 @@ def __init__(self, pipeline_steps: List[PipelineStep]):

async def process_request(
self,
secret_manager: SecretsManager,
request: ChatCompletionRequest,
) -> PipelineResult:
"""
Process a request through all pipeline steps

Args:
request: The chat completion request to process
secret_manager: The secrets manager instance to gather sensitive data from the request

Returns:
PipelineResult containing either a modified request or response structure
"""
context = PipelineContext()
context.sensitive = PipelineSensitiveData(
manager=secret_manager, session_id=str(uuid.uuid4())
) # Generate a new session ID for each request
current_request = request

for step in self.pipeline_steps:
Expand Down
173 changes: 173 additions & 0 deletions src/codegate/pipeline/output.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import AsyncIterator, Optional

from litellm import ModelResponse
from litellm.types.utils import Delta, StreamingChoices

from codegate.pipeline.base import PipelineContext


@dataclass
class OutputPipelineContext:
    """
    Context passed between output pipeline steps.

    Does not include the input context, that one is separate.
    """

    # We store the messages that are not yet sent to the client in the buffer.
    # One reason for this might be that the buffer contains a secret that we want to de-obfuscate
    # The buffer is cleared whenever a chunk makes it through every step and is
    # yielded downstream, and flushed as one final chunk when the stream ends.
    buffer: list[str] = field(default_factory=list)


class OutputPipelineStep(ABC):
    """
    Base class for output pipeline steps.

    The process_chunk method should be implemented by subclasses and handles
    processing of a single chunk of the stream.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Returns the name of this pipeline step"""
        pass

    @abstractmethod
    async def process_chunk(
        self,
        chunk: ModelResponse,
        context: OutputPipelineContext,
        input_context: Optional[PipelineContext] = None,
    ) -> Optional[ModelResponse]:
        """
        Process a single chunk of the stream.

        Args:
        - chunk: The input chunk to process, normalized to ModelResponse
        - context: The output pipeline context. Can be used to store state between steps, mainly
          the buffer.
        - input_context: The input context from processing the user's input. Can include the secrets
          obfuscated in the user message or code snippets in the user message.

        Return:
        - None to pause the stream (the chunk stays buffered until a later
          chunk makes it through the whole pipeline)
        - Modified or unmodified input chunk to pass through
        """
        pass


class OutputPipelineInstance:
    """
    Handles processing of a single stream.
    Think of this class as steps + buffer.
    """

    def __init__(
        self,
        pipeline_steps: list[OutputPipelineStep],
        input_context: Optional[PipelineContext] = None,
    ):
        """
        Args:
        - pipeline_steps: The output steps to run, in order, on every chunk
        - input_context: Optional context from input-pipeline processing;
          passed to each step and used for sensitive-data cleanup at the end
        """
        self._input_context = input_context
        self._pipeline_steps = pipeline_steps
        self._context = OutputPipelineContext()
        # we won't actually buffer the chunk, but in case we need to pass
        # the remaining content in the buffer when the stream ends, we need
        # to store the parameters like model, timestamp, etc.
        self._buffered_chunk = None

    def _buffer_chunk(self, chunk: ModelResponse) -> None:
        """
        Add chunk content to buffer.

        Keeps a reference to the chunk itself so its metadata (id, model,
        created) can be reused when flushing the buffer at end-of-stream, and
        appends each choice's delta content to the text buffer.
        """
        self._buffered_chunk = chunk
        for choice in chunk.choices:
            # the last choice has no delta or content, let's not buffer it
            if choice.delta is not None and choice.delta.content is not None:
                self._context.buffer.append(choice.delta.content)

    async def process_stream(
        self, stream: AsyncIterator[ModelResponse]
    ) -> AsyncIterator[ModelResponse]:
        """
        Process a stream through all pipeline steps.

        Yields each chunk that makes it through every step. When a step
        returns None the stream is paused: nothing is yielded and the content
        stays buffered until a later chunk passes the whole pipeline. Any
        remaining buffered content is flushed as one final chunk when the
        stream ends (normally or with an error), and sensitive input-context
        data is cleaned up.
        """
        # No `except` clause: an exception from the stream or a step simply
        # propagates to the caller; `finally` still flushes the buffer and
        # cleans up. (The previous `except Exception as e: raise e` was a
        # no-op that only reset the traceback origin.)
        try:
            async for chunk in stream:
                # Store chunk content in buffer
                self._buffer_chunk(chunk)

                # Process chunk through each step of the pipeline
                current_chunk = chunk
                for step in self._pipeline_steps:
                    if current_chunk is None:
                        # Stop processing if a step returned None previously
                        # this means that the pipeline step requested to pause the stream
                        # instead, let's try again with the next chunk
                        break

                    processed_chunk = await step.process_chunk(
                        current_chunk, self._context, self._input_context
                    )
                    # the returned chunk becomes the input for the next step in the pipeline
                    current_chunk = processed_chunk

                # we have either gone through all the steps in the pipeline and have a chunk
                # to return or we are paused in which case we don't yield
                if current_chunk is not None:
                    # Step processed successfully, yield the chunk and clear buffer
                    self._context.buffer.clear()
                    yield current_chunk
                # else: keep buffering for next iteration
        finally:
            # Process any remaining content in buffer when stream ends
            if self._context.buffer:
                final_content = "".join(self._context.buffer)
                yield ModelResponse(
                    id=self._buffered_chunk.id,
                    choices=[
                        StreamingChoices(
                            finish_reason=None,
                            # we just put one choice in the buffer, so 0 is fine
                            index=0,
                            delta=Delta(content=final_content, role="assistant"),
                            # NOTE(review): reuses the last buffered chunk's
                            # logprobs; assumes single-choice streams — confirm
                            logprobs=self._buffered_chunk.choices[0].logprobs,
                        )
                    ],
                    created=self._buffered_chunk.created,
                    model=self._buffered_chunk.model,
                    object="chat.completion.chunk",
                )
                self._context.buffer.clear()

            # Cleanup sensitive data through the input context
            if self._input_context and self._input_context.sensitive:
                self._input_context.sensitive.secure_cleanup()


class OutputPipelineProcessor:
    """
    Factory wrapper around the output pipeline.

    Each processed stream gets its own fresh ``OutputPipelineInstance`` so
    that no buffered state leaks between pipeline runs.
    """

    def __init__(self, pipeline_steps: list[OutputPipelineStep]):
        self.pipeline_steps = pipeline_steps

    def _create_instance(self) -> OutputPipelineInstance:
        """Build a fresh pipeline instance for processing a single stream."""
        return OutputPipelineInstance(self.pipeline_steps)

    async def process_stream(
        self, stream: AsyncIterator[ModelResponse]
    ) -> AsyncIterator[ModelResponse]:
        """Run the stream through a newly created pipeline instance."""
        async for chunk in self._create_instance().process_stream(stream):
            yield chunk
112 changes: 112 additions & 0 deletions src/codegate/pipeline/secrets/manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from typing import NamedTuple, Optional

import structlog

from codegate.pipeline.secrets.gatecrypto import CodeGateCrypto

logger = structlog.get_logger("codegate")


class SecretEntry(NamedTuple):
    """Represents a stored secret"""

    # The plaintext secret value as provided to store_secret()
    original: str
    # The encrypted form produced by the crypto layer and handed back to callers
    encrypted: str
    # Name of the service the secret belongs to
    service: str
    # Kind of secret (e.g. token, password) as classified by the caller
    secret_type: str


class SecretsManager:
    """Manages encryption, storage and retrieval of secrets"""

    def __init__(self):
        self.crypto = CodeGateCrypto()
        # session_id -> SecretEntry; a session holds at most one entry
        self._session_store: dict[str, SecretEntry] = {}
        self._encrypted_to_session: dict[str, str] = {}  # Reverse lookup index

    def store_secret(self, value: str, service: str, secret_type: str, session_id: str) -> str:
        """
        Encrypts and stores a secret value.
        Returns the encrypted value.

        Args:
            value: The plaintext secret to protect
            service: Name of the service the secret belongs to
            secret_type: Kind of secret (e.g. token, password)
            session_id: Session the secret is scoped to

        Raises:
            ValueError: If any argument is empty
        """
        if not value:
            raise ValueError("Value must be provided")
        if not service:
            raise ValueError("Service must be provided")
        if not secret_type:
            raise ValueError("Secret type must be provided")
        if not session_id:
            raise ValueError("Session ID must be provided")

        encrypted_value = self.crypto.encrypt_token(value, session_id)

        # If this session already stored a secret, drop the stale reverse-index
        # entry. Otherwise the old encrypted value would keep resolving to this
        # session and get_original_value(old_encrypted, session_id) would
        # wrongly return the *new* original.
        previous = self._session_store.get(session_id)
        if previous is not None:
            self._encrypted_to_session.pop(previous.encrypted, None)

        # Store mappings
        self._session_store[session_id] = SecretEntry(
            original=value,
            encrypted=encrypted_value,
            service=service,
            secret_type=secret_type,
        )
        self._encrypted_to_session[encrypted_value] = session_id

        logger.debug("Stored secret", service=service, type=secret_type, encrypted=encrypted_value)

        return encrypted_value

    def get_original_value(self, encrypted_value: str, session_id: str) -> Optional[str]:
        """Retrieve original value for an encrypted value.

        Returns None when the encrypted value is unknown, belongs to a
        different session, or an error occurs during lookup.
        """
        try:
            stored_session_id = self._encrypted_to_session.get(encrypted_value)
            # Only hand the secret back to the session that stored it
            if stored_session_id == session_id:
                return self._session_store[session_id].original
        except Exception as e:
            logger.error("Error retrieving secret", error=str(e))
        return None

    def get_by_session_id(self, session_id: str) -> Optional[SecretEntry]:
        """Get stored data by session ID, or None if the session has none."""
        return self._session_store.get(session_id)

    def cleanup(self):
        """Securely wipe sensitive data for all sessions."""
        try:
            # Convert and wipe original values.
            # NOTE(review): this wipes a bytearray *copy*; the immutable str
            # itself cannot be zeroed in CPython, so this is best-effort.
            for entry in self._session_store.values():
                original_bytes = bytearray(entry.original.encode())
                self.crypto.wipe_bytearray(original_bytes)

            # Clear the dictionaries
            self._session_store.clear()
            self._encrypted_to_session.clear()

            logger.info("Secrets manager data securely wiped")
        except Exception as e:
            logger.error("Error during secure cleanup", error=str(e))

    def cleanup_session(self, session_id: str):
        """
        Remove a specific session's secrets and perform secure cleanup.

        Args:
            session_id (str): The session identifier to remove
        """
        try:
            # Get the secret entry for the session
            entry = self._session_store.get(session_id)

            if entry:
                # Securely wipe the original value (best-effort, see cleanup())
                original_bytes = bytearray(entry.original.encode())
                self.crypto.wipe_bytearray(original_bytes)

                # Remove the encrypted value from the reverse lookup index
                self._encrypted_to_session.pop(entry.encrypted, None)

                # Remove the session from the store
                self._session_store.pop(session_id, None)

                logger.debug("Session secrets securely removed", session_id=session_id)
            else:
                logger.debug("No secrets found for session", session_id=session_id)
        except Exception as e:
            logger.error("Error during session cleanup", session_id=session_id, error=str(e))
Loading