Skip to content
This repository was archived by the owner on Jun 5, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/codegate/api/v1_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pydantic

from codegate.db import models as db_models
from codegate.pipeline.base import CodeSnippet
from codegate.extract_snippets.message_extractor import CodeSnippet
from codegate.providers.base import BaseProvider
from codegate.providers.registry import ProviderRegistry

Expand Down
1 change: 1 addition & 0 deletions src/codegate/clients/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ class ClientType(Enum):
KODU = "kodu" # Kodu client
COPILOT = "copilot" # Copilot client
OPEN_INTERPRETER = "open_interpreter" # Open Interpreter client
AIDER = "aider" # Aider client
119 changes: 119 additions & 0 deletions src/codegate/extract_snippets/body_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from abc import ABC, abstractmethod
from typing import List, Optional

from codegate.extract_snippets.message_extractor import (
AiderCodeSnippetExtractor,
ClineCodeSnippetExtractor,
CodeSnippetExtractor,
DefaultCodeSnippetExtractor,
OpenInterpreterCodeSnippetExtractor,
)


class BodyCodeSnippetExtractorError(Exception):
    """Error raised by body extractors, e.g. when no snippet extractor has been set."""


class BodyCodeSnippetExtractor(ABC):
    """
    Abstract base for extractors that work on the whole request body sent by a client.

    Child classes assign `_snippet_extractor` and implement `extract_unique_filenames`.
    """

    def __init__(self):
        # The parent only declares the attribute; each child class sets the real extractor.
        self._snippet_extractor: Optional[CodeSnippetExtractor] = None

    def _extract_from_user_messages(self, data: dict) -> set[str]:
        """
        Collect the filenames of the code snippets found in the user messages of `data`.

        Only the entries of `data["messages"]` whose role is "user" are inspected; their
        content is handed to the snippet extractor and the keys of its result are gathered.

        Returns the set of collected filenames.
        Raises BodyCodeSnippetExtractorError if no snippet extractor has been set.
        """
        extractor = self._snippet_extractor
        if extractor is None:
            raise BodyCodeSnippetExtractorError("Code Extractor not set.")

        unique_names = set()
        for message in data.get("messages", []):
            if message.get("role", "") != "user":
                continue
            snippets = extractor.extract_unique_snippets(message.get("content"))
            unique_names.update(snippets.keys())
        return unique_names

    @abstractmethod
    def extract_unique_filenames(self, data: dict) -> set[str]:
        """
        Return the unique filenames found in the data sent by a client (Cline, Continue, ...)
        """


class ContinueBodySnippetExtractor(BodyCodeSnippetExtractor):
    """
    Body extractor that uses DefaultCodeSnippetExtractor on the user messages of a request.
    """

    def __init__(self):
        # Run the base initializer first so any state it sets up is present on the instance.
        super().__init__()
        self._snippet_extractor = DefaultCodeSnippetExtractor()

    def extract_unique_filenames(self, data: dict) -> set[str]:
        """
        Return the set of filenames found in the user messages of `data`.
        """
        return self._extract_from_user_messages(data)


class AiderBodySnippetExtractor(BodyCodeSnippetExtractor):
    """
    Body extractor that uses AiderCodeSnippetExtractor on the user messages of a request.
    """

    def __init__(self):
        # Run the base initializer first so any state it sets up is present on the instance.
        super().__init__()
        self._snippet_extractor = AiderCodeSnippetExtractor()

    def extract_unique_filenames(self, data: dict) -> set[str]:
        """
        Return the set of filenames found in the user messages of `data`.
        """
        return self._extract_from_user_messages(data)


class ClineBodySnippetExtractor(BodyCodeSnippetExtractor):
    """
    Body extractor that uses ClineCodeSnippetExtractor on the user messages of a request.
    """

    def __init__(self):
        # Run the base initializer first so any state it sets up is present on the instance.
        super().__init__()
        self._snippet_extractor = ClineCodeSnippetExtractor()

    def extract_unique_filenames(self, data: dict) -> set[str]:
        """
        Return the set of filenames found in the user messages of `data`.
        """
        return self._extract_from_user_messages(data)


class OpenInterpreterBodySnippetExtractor(BodyCodeSnippetExtractor):
    """
    Body extractor that uses OpenInterpreterCodeSnippetExtractor.

    Unlike the extractors that read user messages, this one inspects every assistant tool
    call that is immediately followed by a tool result message and feeds both texts to the
    snippet extractor.
    """

    def __init__(self):
        # Run the base initializer first so any state it sets up is present on the instance.
        super().__init__()
        self._snippet_extractor = OpenInterpreterCodeSnippetExtractor()

    def _is_msg_tool_call(self, msg: dict) -> bool:
        """
        Tell whether `msg` is an assistant message carrying at least one tool call.
        """
        # bool() so the declared return type holds instead of leaking the raw list.
        return msg.get("role", "") == "assistant" and bool(msg.get("tool_calls"))

    def _is_msg_tool_result(self, msg: dict) -> bool:
        """
        Tell whether `msg` is a tool message with non-empty content.
        """
        # bool() so the declared return type holds instead of leaking the raw content.
        return msg.get("role", "") == "tool" and bool(msg.get("content"))

    def _extract_args_from_tool_call(self, msg: dict) -> str:
        """
        Extract the arguments from the tool call message.

        Only the first entry of `tool_calls` is read. An empty string is returned when there
        is no tool call, no function entry or no arguments.
        """
        tool_calls = msg.get("tool_calls") or []
        if not tool_calls:
            return ""
        # `or` fallbacks guard against keys that are present but explicitly set to None.
        function = tool_calls[0].get("function") or {}
        return function.get("arguments") or ""

    def _extract_result_from_tool_result(self, msg: dict) -> str:
        """
        Extract the result from the tool result message (empty string when missing).
        """
        return msg.get("content") or ""

    def extract_unique_filenames(self, data: dict) -> set[str]:
        """
        Return the set of filenames found in the tool call / tool result pairs of `data`.

        Each message is paired with the one right after it; a pair is used only when the
        first is an assistant tool call and the second is a tool result.
        """
        messages = data.get("messages") or []

        filenames: set[str] = set()
        # zip() stops at the shorter slice, so zero or one message yields no pairs.
        for msg, next_msg in zip(messages, messages[1:]):
            if not (self._is_msg_tool_call(msg) and self._is_msg_tool_result(next_msg)):
                continue
            tool_args = self._extract_args_from_tool_call(msg)
            tool_response = self._extract_result_from_tool_result(next_msg)
            extracted_snippets = self._snippet_extractor.extract_unique_snippets(
                f"{tool_args}\n{tool_response}"
            )
            filenames.update(extracted_snippets.keys())
        return filenames
41 changes: 41 additions & 0 deletions src/codegate/extract_snippets/factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
from codegate.clients.clients import ClientType
from codegate.extract_snippets.body_extractor import (
AiderBodySnippetExtractor,
BodyCodeSnippetExtractor,
ClineBodySnippetExtractor,
ContinueBodySnippetExtractor,
OpenInterpreterBodySnippetExtractor,
)
from codegate.extract_snippets.message_extractor import (
AiderCodeSnippetExtractor,
ClineCodeSnippetExtractor,
CodeSnippetExtractor,
DefaultCodeSnippetExtractor,
OpenInterpreterCodeSnippetExtractor,
)


class BodyCodeExtractorFactory:
    """
    Factory returning the body-level snippet extractor that matches a detected client.
    """

    @staticmethod
    def create_snippet_extractor(detected_client: ClientType) -> BodyCodeSnippetExtractor:
        """
        Create the body snippet extractor for `detected_client`.

        Clients without a dedicated entry fall back to ContinueBodySnippetExtractor.
        A new extractor instance is returned on every call.
        """
        # Map to classes, not instances, so only the selected extractor is constructed.
        mapping_client_extractor = {
            ClientType.GENERIC: ContinueBodySnippetExtractor,
            ClientType.CLINE: ClineBodySnippetExtractor,
            ClientType.AIDER: AiderBodySnippetExtractor,
            ClientType.OPEN_INTERPRETER: OpenInterpreterBodySnippetExtractor,
        }
        extractor_cls = mapping_client_extractor.get(
            detected_client, ContinueBodySnippetExtractor
        )
        return extractor_cls()


class MessageCodeExtractorFactory:
    """
    Factory returning the message-level snippet extractor that matches a detected client.
    """

    @staticmethod
    def create_snippet_extractor(detected_client: ClientType) -> CodeSnippetExtractor:
        """
        Create the message snippet extractor for `detected_client`.

        Clients without a dedicated entry fall back to DefaultCodeSnippetExtractor.
        A new extractor instance is returned on every call.
        """
        # Map to classes, not instances, so only the selected extractor is constructed.
        mapping_client_extractor = {
            ClientType.GENERIC: DefaultCodeSnippetExtractor,
            ClientType.CLINE: ClineCodeSnippetExtractor,
            ClientType.AIDER: AiderCodeSnippetExtractor,
            ClientType.OPEN_INTERPRETER: OpenInterpreterCodeSnippetExtractor,
        }
        extractor_cls = mapping_client_extractor.get(
            detected_client, DefaultCodeSnippetExtractor
        )
        return extractor_cls()
Loading