This repository was archived by the owner on Jun 5, 2025. It is now read-only.
Merged
6 changes: 3 additions & 3 deletions scripts/import_packages.py
@@ -1,9 +1,8 @@
import argparse
import asyncio
import json
import os
import shutil
import argparse


import weaviate
from weaviate.classes.config import DataType, Property
@@ -134,7 +133,8 @@ async def run_import(self):

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Run the package importer with optional backup flags.")
description="Run the package importer with optional backup flags."
)
parser.add_argument(
"--take-backup",
type=lambda x: x.lower() == "true",
58 changes: 37 additions & 21 deletions src/codegate/pipeline/base.py
@@ -210,44 +210,35 @@ async def process(
pass


class SequentialPipelineProcessor:
def __init__(self, pipeline_steps: List[PipelineStep]):
class InputPipelineInstance:
def __init__(self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager):
self.pipeline_steps = pipeline_steps
self.secret_manager = secret_manager
self.context = PipelineContext()

async def process_request(
self,
secret_manager: SecretsManager,
request: ChatCompletionRequest,
provider: str,
prompt_id: str,
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> PipelineResult:
"""
Process a request through all pipeline steps

Args:
secret_manager: The secrets manager instance to gather sensitive data from the request
request: The chat completion request to process

Returns:
PipelineResult containing either a modified request or response structure
"""
context = PipelineContext()
context.sensitive = PipelineSensitiveData(
manager=secret_manager,
"""Process a request through all pipeline steps"""
self.context.sensitive = PipelineSensitiveData(
manager=self.secret_manager,
session_id=str(uuid.uuid4()),
api_key=api_key,
model=model,
provider=provider,
api_base=api_base,
) # Generate a new session ID for each request
context.metadata["prompt_id"] = prompt_id
)
self.context.metadata["prompt_id"] = prompt_id
current_request = request

for step in self.pipeline_steps:
result = await step.process(current_request, context)
result = await step.process(current_request, self.context)
if result is None:
continue

@@ -258,6 +249,31 @@ async def process_request(
current_request = result.request

if result.context is not None:
context = result.context
self.context = result.context

return PipelineResult(request=current_request, context=self.context)


class SequentialPipelineProcessor:
def __init__(self, pipeline_steps: List[PipelineStep], secret_manager: SecretsManager):
self.pipeline_steps = pipeline_steps
self.secret_manager = secret_manager

def create_instance(self) -> InputPipelineInstance:
"""Create a new pipeline instance for processing a request"""
return InputPipelineInstance(self.pipeline_steps, self.secret_manager)

return PipelineResult(request=current_request, context=context)
async def process_request(
self,
request: ChatCompletionRequest,
provider: str,
prompt_id: str,
model: str,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> PipelineResult:
"""Create a new pipeline instance and process the request"""
instance = self.create_instance()
return await instance.process_request(
request, provider, prompt_id, model, api_key, api_base
)
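
The refactor above separates shared configuration from per-request state: SequentialPipelineProcessor now holds only the step list and the SecretsManager, and every call to process_request builds a fresh InputPipelineInstance with its own PipelineContext. A minimal usage sketch follows; the empty step list and the no-argument SecretsManager() construction are assumptions made for illustration, not taken from this PR.

import asyncio

from litellm import ChatCompletionRequest

from codegate.pipeline.base import SequentialPipelineProcessor
from codegate.pipeline.secrets.manager import SecretsManager


async def main() -> None:
    # Shared, request-independent configuration: the pipeline steps and the secrets manager.
    # An empty step list keeps the sketch self-contained (assumption).
    processor = SequentialPipelineProcessor(pipeline_steps=[], secret_manager=SecretsManager())

    request = ChatCompletionRequest(
        model="example-model",
        messages=[{"role": "user", "content": "hello"}],
    )

    # process_request() creates a fresh InputPipelineInstance (and PipelineContext)
    # per call, so per-request state is never shared between concurrent requests.
    result = await processor.process_request(
        request=request,
        provider="anthropic",
        prompt_id="example-prompt-id",
        model="example-model",
    )
    print(result.request)


asyncio.run(main())
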
5 changes: 5 additions & 0 deletions src/codegate/pipeline/output.py
@@ -2,11 +2,14 @@
from dataclasses import dataclass, field
from typing import AsyncIterator, List, Optional

import structlog
from litellm import ModelResponse
from litellm.types.utils import Delta, StreamingChoices

from codegate.pipeline.base import CodeSnippet, PipelineContext

logger = structlog.get_logger("codegate")


@dataclass
class OutputPipelineContext:
@@ -131,12 +134,14 @@ async def process_stream(

# Yield all processed chunks
for c in current_chunks:
logger.debug(f"Yielding chunk {c}")
self._store_chunk_content(c)
self._context.buffer.clear()
yield c

except Exception as e:
# Log exception and stop processing
logger.error(f"Error processing stream: {e}")
raise e
finally:
# Process any remaining content in buffer when stream ends
20 changes: 12 additions & 8 deletions src/codegate/providers/anthropic/adapter.py
@@ -1,17 +1,18 @@
from typing import Optional

import litellm
from litellm import ChatCompletionRequest
from litellm.adapters.anthropic_adapter import (
AnthropicAdapter as LitellmAnthropicAdapter,
)
from litellm.types.llms.anthropic import (
AnthropicMessagesRequest,
)

from codegate.providers.litellmshim.adapter import (
LiteLLMAdapterInputNormalizer,
LiteLLMAdapterOutputNormalizer,
)
import litellm
from litellm import ChatCompletionRequest
from litellm.types.llms.anthropic import (
AnthropicMessagesRequest,
)


class AnthropicAdapter(LitellmAnthropicAdapter):
@@ -22,10 +23,13 @@ def translate_completion_input_params(self, kwargs) -> Optional[ChatCompletionRe
request_body = AnthropicMessagesRequest(**kwargs) # type: ignore
if not request_body.get("system"):
request_body["system"] = "System prompt"
translated_body = litellm.AnthropicExperimentalPassThroughConfig()\
.translate_anthropic_to_openai(anthropic_message_request=request_body)
translated_body = (
litellm.AnthropicExperimentalPassThroughConfig().translate_anthropic_to_openai(
anthropic_message_request=request_body
)
)
return translated_body


class AnthropicInputNormalizer(LiteLLMAdapterInputNormalizer):
"""
8 changes: 4 additions & 4 deletions src/codegate/providers/anthropic/provider.py
@@ -3,31 +3,31 @@

from fastapi import Header, HTTPException, Request

from codegate.pipeline.base import SequentialPipelineProcessor
from codegate.pipeline.output import OutputPipelineProcessor
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer
from codegate.providers.anthropic.completion_handler import AnthropicCompletion
from codegate.providers.base import BaseProvider, SequentialPipelineProcessor
from codegate.providers.base import BaseProvider
from codegate.providers.litellmshim import anthropic_stream_generator


class AnthropicProvider(BaseProvider):
def __init__(
self,
secrets_manager: SecretsManager,
pipeline_processor: Optional[SequentialPipelineProcessor] = None,
fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
):
completion_handler = AnthropicCompletion(stream_generator=anthropic_stream_generator)
super().__init__(
secrets_manager,
AnthropicInputNormalizer(),
AnthropicOutputNormalizer(),
completion_handler,
pipeline_processor,
fim_pipeline_processor,
output_pipeline_processor,
fim_output_pipeline_processor,
)

@property
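
Because the provider now receives the SecretsManager directly, wiring it up looks roughly as below. This is an assumed sketch: the real registration happens in codegate's server startup, the no-argument SecretsManager() construction is not shown in this PR, and the empty step lists are placeholders.

from fastapi import FastAPI

from codegate.pipeline.base import SequentialPipelineProcessor
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.providers.anthropic.provider import AnthropicProvider

secrets_manager = SecretsManager()  # assumed no-argument construction
provider = AnthropicProvider(
    secrets_manager,
    pipeline_processor=SequentialPipelineProcessor([], secrets_manager),
    fim_pipeline_processor=SequentialPipelineProcessor([], secrets_manager),
)

# BaseProvider.get_routes() returns the provider's APIRouter; mount it on the app.
app = FastAPI()
app.include_router(provider.get_routes())
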
51 changes: 39 additions & 12 deletions src/codegate/providers/base.py
@@ -8,12 +8,16 @@
from litellm.types.llms.openai import ChatCompletionRequest

from codegate.db.connection import DbRecorder
from codegate.pipeline.base import PipelineContext, PipelineResult, SequentialPipelineProcessor
from codegate.pipeline.base import (
PipelineContext,
PipelineResult,
SequentialPipelineProcessor,
)
from codegate.pipeline.output import OutputPipelineInstance, OutputPipelineProcessor
from codegate.pipeline.secrets.manager import SecretsManager
from codegate.providers.completion.base import BaseCompletionHandler
from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter
from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer
from codegate.providers.normalizer.completion import CompletionNormalizer

logger = structlog.get_logger("codegate")

@@ -28,26 +32,27 @@ class BaseProvider(ABC):

def __init__(
self,
secrets_manager: Optional[SecretsManager],
input_normalizer: ModelInputNormalizer,
output_normalizer: ModelOutputNormalizer,
completion_handler: BaseCompletionHandler,
pipeline_processor: Optional[SequentialPipelineProcessor] = None,
fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None,
output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None,
):
self.router = APIRouter()
self._secrets_manager = secrets_manager
self._completion_handler = completion_handler
self._input_normalizer = input_normalizer
self._output_normalizer = output_normalizer
self._pipeline_processor = pipeline_processor
self._fim_pipelin_processor = fim_pipeline_processor
self._output_pipeline_processor = output_pipeline_processor
self._fim_output_pipeline_processor = fim_output_pipeline_processor
self._db_recorder = DbRecorder()
self._pipeline_response_formatter = PipelineResponseFormatter(
output_normalizer, self._db_recorder
)
self._fim_normalizer = CompletionNormalizer()

self._setup_routes()

@@ -63,13 +68,33 @@ def provider_route_name(self) -> str:
async def _run_output_stream_pipeline(
self,
input_context: PipelineContext,
normalized_stream: AsyncIterator[ModelResponse],
model_stream: AsyncIterator[ModelResponse],
is_fim_request: bool,
) -> AsyncIterator[ModelResponse]:
# Decide which pipeline processor to use
out_pipeline_processor = None
if is_fim_request:
out_pipeline_processor = self._fim_output_pipeline_processor
logger.info("FIM pipeline selected for output.")
else:
out_pipeline_processor = self._output_pipeline_processor
logger.info("Chat completion pipeline selected for output.")
if out_pipeline_processor is None:
logger.info("No output pipeline processor found, passing through")
return model_stream
if len(out_pipeline_processor.pipeline_steps) == 0:
logger.info("No output pipeline steps configured, passing through")
return model_stream

normalized_stream = self._output_normalizer.normalize_streaming(model_stream)

output_pipeline_instance = OutputPipelineInstance(
self._output_pipeline_processor.pipeline_steps,
pipeline_steps=out_pipeline_processor.pipeline_steps,
input_context=input_context,
)
return output_pipeline_instance.process_stream(normalized_stream)
pipeline_output_stream = output_pipeline_instance.process_stream(normalized_stream)
denormalized_stream = self._output_normalizer.denormalize_streaming(pipeline_output_stream)
return denormalized_stream

def _run_output_pipeline(
self,
Expand All @@ -90,14 +115,14 @@ async def _run_input_pipeline(
if is_fim_request:
pipeline_processor = self._fim_pipelin_processor
logger.info("FIM pipeline selected for execution.")
normalized_request = self._fim_normalizer.normalize(normalized_request)
else:
pipeline_processor = self._pipeline_processor
logger.info("Chat completion pipeline selected for execution.")
if pipeline_processor is None:
return PipelineResult(request=normalized_request)

result = await pipeline_processor.process_request(
secret_manager=self._secrets_manager,
request=normalized_request,
provider=self.provider_route_name,
prompt_id=prompt_id,
@@ -208,10 +233,13 @@ async def complete(
)

provider_request = self._input_normalizer.denormalize(input_pipeline_result.request)
if is_fim_request:
provider_request = self._fim_normalizer.denormalize(provider_request)

# Execute the completion and translate the response
# This gives us either a single response or a stream of responses
# based on the streaming flag
logger.info(f"Executing completion with {provider_request}")
model_response = await self._completion_handler.execute_completion(
provider_request, api_key=api_key, stream=streaming, is_fim_request=is_fim_request
)
@@ -230,13 +258,12 @@ async def complete(
return self._output_normalizer.denormalize(pipeline_output)

model_response = self._db_recorder.record_output_stream(prompt_db, model_response)
normalized_stream = self._output_normalizer.normalize_streaming(model_response)
pipeline_output_stream = await self._run_output_stream_pipeline(
input_pipeline_result.context,
normalized_stream,
model_response,
is_fim_request=is_fim_request,
)
denormalized_stream = self._output_normalizer.denormalize_streaming(pipeline_output_stream)
return self._cleanup_after_streaming(denormalized_stream, input_pipeline_result.context)
return self._cleanup_after_streaming(pipeline_output_stream, input_pipeline_result.context)

def get_routes(self) -> APIRouter:
return self.router
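
Taken together, the base-provider changes move FIM-specific request normalization into the shared CompletionNormalizer and pull output-stream (de)normalization plus pipeline selection into _run_output_stream_pipeline. The following is a condensed, non-literal sketch of the FIM path through complete(); attribute names mirror the diff (including the pre-existing _fim_pipelin_processor spelling), while the control flow and the prompt_id value are simplified for illustration.

async def fim_flow_sketch(provider, data: dict, api_key: str):
    # Provider-specific input normalization first, then the shared FIM
    # normalizer (prompt -> messages) introduced in this PR.
    request = provider._input_normalizer.normalize(data)
    request = provider._fim_normalizer.normalize(request)

    result = await provider._fim_pipelin_processor.process_request(
        request=request,
        provider=provider.provider_route_name,
        prompt_id="example-prompt-id",  # illustrative value
        model=request["model"],
        api_key=api_key,
    )

    # Denormalize in the same order before calling the upstream provider:
    # provider normalizer first, then the FIM normalizer (messages -> prompt).
    provider_request = provider._input_normalizer.denormalize(result.request)
    provider_request = provider._fim_normalizer.denormalize(provider_request)

    stream = await provider._completion_handler.execute_completion(
        provider_request, api_key=api_key, stream=True, is_fim_request=True
    )

    # Output-stream normalization, FIM/chat pipeline selection, and
    # denormalization now all happen inside _run_output_stream_pipeline.
    return await provider._run_output_stream_pipeline(result.context, stream, is_fim_request=True)
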
19 changes: 1 addition & 18 deletions src/codegate/providers/llamacpp/normalizer.py
@@ -18,29 +18,12 @@ def normalize(self, data: Dict) -> ChatCompletionRequest:
Normalize the input data
"""
# Make a copy of the data to avoid modifying the original and normalize the message content
normalized_data = self._normalize_content_messages(data)

# When doing FIM, we receive "prompt" instead of messages. Normalizing.
if "prompt" in normalized_data:
normalized_data["messages"] = [
{"content": normalized_data.pop("prompt"), "role": "user"}
]
# We can add as many parameters as we like to data. ChatCompletionRequest is not strict.
normalized_data["had_prompt_before"] = True
try:
return ChatCompletionRequest(**normalized_data)
except Exception as e:
raise ValueError(f"Invalid completion parameters: {str(e)}")
return self._normalize_content_messages(data)

def denormalize(self, data: ChatCompletionRequest) -> Dict:
"""
Denormalize the input data
"""
# If we receive "prompt" in FIM, we need convert it back.
if data.get("had_prompt_before", False):
data["prompt"] = data["messages"][0]["content"]
del data["had_prompt_before"]
del data["messages"]
return data


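
The prompt handling removed here appears to move into the shared CompletionNormalizer that base.py now imports from codegate.providers.normalizer.completion; that file is not part of this diff, so its exact contents are an assumption. Reconstructed from the removed lines, the behaviour being centralized is roughly the following.

from typing import Dict

from litellm import ChatCompletionRequest


def fim_normalize(data: Dict) -> ChatCompletionRequest:
    # When doing FIM we receive "prompt" instead of "messages": convert it into
    # a single user message and remember that the conversion happened.
    normalized = dict(data)
    if "prompt" in normalized:
        normalized["messages"] = [{"content": normalized.pop("prompt"), "role": "user"}]
        # ChatCompletionRequest is not strict, so extra bookkeeping keys are allowed.
        normalized["had_prompt_before"] = True
    try:
        return ChatCompletionRequest(**normalized)
    except Exception as e:
        raise ValueError(f"Invalid completion parameters: {str(e)}")


def fim_denormalize(data: ChatCompletionRequest) -> Dict:
    # Restore "prompt" before the request leaves codegate for the model server.
    if data.get("had_prompt_before", False):
        data["prompt"] = data["messages"][0]["content"]
        del data["had_prompt_before"]
        del data["messages"]
    return data
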