Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions src/codegate/providers/copilot/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ async def process_body(
result = await self.instance.process_request(
request=normalized_body,
provider=self.provider_name,
model=normalized_body.get("model", "gpt-4o-mini"),
model=normalized_body.model, # TODO: There was a default value here of gpt-4o-mini. Retain?
api_key=headers_dict.get("authorization", "").replace("Bearer ", ""),
api_base="https://" + headers_dict.get("host", ""),
extra_headers=CopilotPipeline._get_copilot_headers(headers_dict),
Expand All @@ -122,7 +122,7 @@ async def process_body(
try:
# Return shortcut response to the user
body = CopilotPipeline._create_shortcut_response(
result, normalized_body.get("model", "gpt-4o-mini")
result, normalized_body.model,
)
logger.info(f"Pipeline created shortcut response: {body}")
return body, result.context
Expand Down Expand Up @@ -155,12 +155,13 @@ def __init__(self):
self._completion_normalizer = CompletionNormalizer()

def normalize(self, body: bytes) -> ChatCompletionRequest:
json_body = json.loads(body)
return self._completion_normalizer.normalize(json_body)
return ChatCompletionRequest.model_validate_json(body)

def denormalize(self, request_from_pipeline: ChatCompletionRequest) -> bytes:
normalized_json_body = self._completion_normalizer.denormalize(request_from_pipeline)
return json.dumps(normalized_json_body).encode()
return request_from_pipeline.model_dump_json(
exclude_none=True,
exclude_unset=True,
).encode('utf-8')


class CopilotChatNormalizer:
Expand All @@ -171,19 +172,19 @@ class CopilotChatNormalizer:
"""

def normalize(self, body: bytes) -> ChatCompletionRequest:
json_body = json.loads(body)
normalized_data = ChatCompletionRequest(**json_body)
return ChatCompletionRequest.model_validate_json(body)

# This would normally be the required to get the token usage with OpenAI models.
# However the response comes back empty with Copilot. Commenting for the moment.
# It's not critical since Copilot charges a fixed rate and not based on tokens.
# if normalized_data.get("stream", False):
# normalized_data["stream_options"] = {"include_usage": True}

return normalized_data

def denormalize(self, request_from_pipeline: ChatCompletionRequest) -> bytes:
return json.dumps(request_from_pipeline).encode()
return request_from_pipeline.model_dump_json(
exclude_none=True,
exclude_unset=True,
).encode('utf-8')


class CopilotFimPipeline(CopilotPipeline):
Expand Down
40 changes: 6 additions & 34 deletions src/codegate/providers/copilot/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ModelResponse,
StreamingChoices,
)
from codegate.types.openai import StreamingChatCompletion

setup_logging()
logger = structlog.get_logger("codegate").bind(origin="copilot_proxy")
Expand Down Expand Up @@ -816,7 +817,7 @@ def __init__(self, proxy: CopilotProvider):
self.headers_sent = False
self.sse_processor: Optional[SSEProcessor] = None
self.output_pipeline_instance: Optional[OutputPipelineInstance] = None
self.stream_queue: Optional[asyncio.Queue] = None
self.stream_queue: Optional[asyncio.Queue[StreamingChatCompletion]] = None
self.processing_task: Optional[asyncio.Task] = None

self.finish_stream = False
Expand Down Expand Up @@ -860,40 +861,11 @@ async def _process_stream(self): # noqa: C901
async def stream_iterator():
while not self.stream_queue.empty():
incoming_record = await self.stream_queue.get()

record_content = incoming_record.get("content", {})

streaming_choices = []
for choice in record_content.get("choices", []):
is_fim = self.proxy.context_tracking.metadata.get("is_fim", False)
if is_fim:
content = choice.get("text", "")
else:
content = choice.get("delta", {}).get("content")

if choice.get("finish_reason", None) == "stop":
for choice in incoming_record.choices:
if choice.finish_reason and \
choice.finish_reason in ["stop", "length", "content_filter"]:
self.finish_stream = True

streaming_choices.append(
StreamingChoices(
finish_reason=choice.get("finish_reason", None),
index=choice.get("index", 0),
delta=Delta(content=content, role="assistant"),
logprobs=choice.get("logprobs", None),
p=choice.get("p", None),
)
)

# Convert record to ModelResponse
mr = ModelResponse(
id=record_content.get("id", ""),
choices=streaming_choices,
created=record_content.get("created", 0),
model=record_content.get("model", ""),
object="chat.completion.chunk",
stream=True,
)
yield mr
yield incoming_record

async for record in self.output_pipeline_instance.process_stream(
stream_iterator(), cleanup_sensitive=False
Expand Down
21 changes: 13 additions & 8 deletions src/codegate/providers/copilot/streaming.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import json
from typing import List

import structlog
from pydantic import ValidationError

from src.codegate.types.openai import StreamingChatCompletion

logger = structlog.get_logger("codegate")

Expand All @@ -12,7 +15,7 @@ def __init__(self):
self.chunk_size = None # Store the original chunk size
self.size_written = False

def process_chunk(self, chunk: bytes) -> list:
def process_chunk(self, chunk: bytes) -> List[StreamingChatCompletion]:
# Skip any chunk size lines (hex number followed by \r\n)
try:
chunk_str = chunk.decode("utf-8")
Expand All @@ -24,7 +27,7 @@ def process_chunk(self, chunk: bytes) -> list:
except UnicodeDecodeError:
logger.error("Failed to decode chunk")

records = []
records: List[StreamingChatCompletion] = []
while True:
record_end = self.buffer.find("\n\n")
if record_end == -1:
Expand All @@ -36,13 +39,15 @@ def process_chunk(self, chunk: bytes) -> list:
if record.startswith("data: "):
data_content = record[6:]
if data_content.strip() == "[DONE]":
records.append({"type": "done"})
# We don't actually need to do anything with this message as the caller relies
# on the stop_reason
logger.debug("Received DONE message")
else:
try:
data = json.loads(data_content)
records.append({"type": "data", "content": data})
except json.JSONDecodeError:
logger.debug(f"Failed to parse JSON: {data_content}")
record = StreamingChatCompletion.model_validate_json(data_content)
records.append(record)
except ValidationError as e:
logger.debug(f"Failed to parse JSON: {data_content}: {e}")

return records

Expand Down