From 7fe7b6ed2de09d02ff71e60aa1c08e462335d4ee Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Tue, 4 Feb 2025 16:14:22 +0200 Subject: [PATCH 1/2] Extended the CodeSnippetExtractor functionality for muxing We used to have a generic CodeSnippetExtractor to detect code snippets and the corresponding filepaths. This PR extends it's functionality to be able to detect snippets coming from different clients. After this PR the supported and tested clients are: - Continue - Cline - Aider - Open Interpreter The main reason behind this change is to be able to catch the filepaths coming through CodeGate, detect it's destination and mux according to them. The PR also adds the necessary in the router used for muxing to detect the filepaths from the request messages. Since CodeSnippetExtractor is no longer a pipeline step it was moved out of the folder `pipeline` to it's own folder --- src/codegate/api/v1_models.py | 2 +- src/codegate/clients/clients.py | 1 + .../extract_snippets/__init__.py | 0 .../extract_snippets/body_extractor.py | 104 +++ src/codegate/extract_snippets/factory.py | 21 + .../extract_snippets/message_extractor.py | 345 ++++++++ src/codegate/muxing/router.py | 50 +- src/codegate/pipeline/base.py | 21 +- .../codegate_context_retriever/codegate.py | 7 +- src/codegate/pipeline/comment/__init__.py | 0 .../{extract_snippets => comment}/output.py | 12 +- .../extract_snippets/extract_snippets.py | 131 --- src/codegate/pipeline/factory.py | 2 +- src/codegate/pipeline/output.py | 3 +- tests/extract_snippets/test_body_extractor.py | 159 ++++ .../test_message_extractor.py | 804 ++++++++++++++++++ .../extract_snippets/test_extract_snippets.py | 186 ---- 17 files changed, 1494 insertions(+), 354 deletions(-) rename src/codegate/{pipeline => }/extract_snippets/__init__.py (100%) create mode 100644 src/codegate/extract_snippets/body_extractor.py create mode 100644 src/codegate/extract_snippets/factory.py create mode 100644 src/codegate/extract_snippets/message_extractor.py create mode 100644 src/codegate/pipeline/comment/__init__.py rename src/codegate/pipeline/{extract_snippets => comment}/output.py (94%) delete mode 100644 src/codegate/pipeline/extract_snippets/extract_snippets.py create mode 100644 tests/extract_snippets/test_body_extractor.py create mode 100644 tests/extract_snippets/test_message_extractor.py delete mode 100644 tests/pipeline/extract_snippets/test_extract_snippets.py diff --git a/src/codegate/api/v1_models.py b/src/codegate/api/v1_models.py index 88da7113..d0495c4a 100644 --- a/src/codegate/api/v1_models.py +++ b/src/codegate/api/v1_models.py @@ -5,7 +5,7 @@ import pydantic from codegate.db import models as db_models -from codegate.pipeline.base import CodeSnippet +from codegate.extract_snippets.message_extractor import CodeSnippet from codegate.providers.base import BaseProvider from codegate.providers.registry import ProviderRegistry diff --git a/src/codegate/clients/clients.py b/src/codegate/clients/clients.py index 0c1daf2e..840a5729 100644 --- a/src/codegate/clients/clients.py +++ b/src/codegate/clients/clients.py @@ -11,3 +11,4 @@ class ClientType(Enum): KODU = "kodu" # Kodu client COPILOT = "copilot" # Copilot client OPEN_INTERPRETER = "open_interpreter" # Open Interpreter client + AIDER = "aider" # Aider client diff --git a/src/codegate/pipeline/extract_snippets/__init__.py b/src/codegate/extract_snippets/__init__.py similarity index 100% rename from src/codegate/pipeline/extract_snippets/__init__.py rename to src/codegate/extract_snippets/__init__.py diff --git a/src/codegate/extract_snippets/body_extractor.py b/src/codegate/extract_snippets/body_extractor.py new file mode 100644 index 00000000..e885017a --- /dev/null +++ b/src/codegate/extract_snippets/body_extractor.py @@ -0,0 +1,104 @@ +from abc import ABC, abstractmethod +from typing import List + +from codegate.extract_snippets.message_extractor import ( + AiderCodeSnippetExtractor, + ClineCodeSnippetExtractor, + CodeSnippetExtractor, + DefaultCodeSnippetExtractor, + OpenInterpreterCodeSnippetExtractor, +) + + +class BodyCodeSnippetExtractor(ABC): + + def __init__(self): + # Initialize the extractor in parent class. The child classes will set the extractor. + self._snippet_extractor: CodeSnippetExtractor = None + + def _extract_from_user_messages(self, data: dict) -> set[str]: + copied_data = data.copy() + filenames: List[str] = [] + for msg in copied_data.get("messages", []): + if msg.get("role", "") == "user": + extracted_snippets = self._snippet_extractor.extract_unique_snippets( + msg.get("content") + ) + filenames.extend(extracted_snippets.keys()) + return set(filenames) + + @abstractmethod + def extract_unique_snippets(self, data: dict) -> set[str]: + pass + + +class ContinueBodySnippetExtractor(BodyCodeSnippetExtractor): + + def __init__(self): + self._snippet_extractor = DefaultCodeSnippetExtractor() + + def extract_unique_snippets(self, data: dict) -> set[str]: + return self._extract_from_user_messages(data) + + +class AiderBodySnippetExtractor(BodyCodeSnippetExtractor): + + def __init__(self): + self._snippet_extractor = AiderCodeSnippetExtractor() + + def extract_unique_snippets(self, data: dict) -> set[str]: + return self._extract_from_user_messages(data) + + +class ClineBodySnippetExtractor(BodyCodeSnippetExtractor): + + def __init__(self): + self._snippet_extractor = ClineCodeSnippetExtractor() + + def extract_unique_snippets(self, data: dict) -> set[str]: + return self._extract_from_user_messages(data) + + +class OpenInterpreterBodySnippetExtractor(BodyCodeSnippetExtractor): + + def __init__(self): + self._snippet_extractor = OpenInterpreterCodeSnippetExtractor() + + def _is_msg_tool_call(self, msg: dict) -> bool: + return msg.get("role", "") == "assistant" and msg.get("tool_calls", []) + + def _is_msg_tool_result(self, msg: dict) -> bool: + return msg.get("role", "") == "tool" and msg.get("content", "") + + def _extract_args_from_tool_call(self, msg: dict) -> str: + """ + Extract the arguments from the tool call message. + """ + tool_calls = msg.get("tool_calls", []) + if not tool_calls: + return "" + return tool_calls[0].get("function", {}).get("arguments", "") + + def _extract_result_from_tool_result(self, msg: dict) -> str: + """ + Extract the result from the tool result message. + """ + return msg.get("content", "") + + def extract_unique_snippets(self, data: dict) -> set[str]: + messages = data.get("messages", []) + if not messages: + return set() + + filenames: List[str] = [] + for i_msg in range(len(messages) - 1): + msg = messages[i_msg] + next_msg = messages[i_msg + 1] + if self._is_msg_tool_call(msg) and self._is_msg_tool_result(next_msg): + tool_args = self._extract_args_from_tool_call(msg) + tool_response = self._extract_result_from_tool_result(next_msg) + extracted_snippets = self._snippet_extractor.extract_unique_snippets( + f"{tool_args}\n{tool_response}" + ) + filenames.extend(extracted_snippets.keys()) + return set(filenames) diff --git a/src/codegate/extract_snippets/factory.py b/src/codegate/extract_snippets/factory.py new file mode 100644 index 00000000..3c14a4a9 --- /dev/null +++ b/src/codegate/extract_snippets/factory.py @@ -0,0 +1,21 @@ +from codegate.clients.clients import ClientType +from codegate.extract_snippets.body_extractor import ( + AiderBodySnippetExtractor, + BodyCodeSnippetExtractor, + ClineBodySnippetExtractor, + ContinueBodySnippetExtractor, + OpenInterpreterBodySnippetExtractor, +) + + +class CodeSnippetExtractorFactory: + + @staticmethod + def create_snippet_extractor(detected_client: ClientType) -> BodyCodeSnippetExtractor: + mapping_client_extractor = { + ClientType.GENERIC: ContinueBodySnippetExtractor(), + ClientType.CLINE: ClineBodySnippetExtractor(), + ClientType.AIDER: AiderBodySnippetExtractor(), + ClientType.OPEN_INTERPRETER: OpenInterpreterBodySnippetExtractor(), + } + return mapping_client_extractor.get(detected_client, ContinueBodySnippetExtractor()) diff --git a/src/codegate/extract_snippets/message_extractor.py b/src/codegate/extract_snippets/message_extractor.py new file mode 100644 index 00000000..e9a7c968 --- /dev/null +++ b/src/codegate/extract_snippets/message_extractor.py @@ -0,0 +1,345 @@ +import re +from abc import ABC, abstractmethod +from pathlib import Path +from typing import Dict, List, Optional, Self + +import structlog +from pydantic import BaseModel, field_validator, model_validator +from pygments.lexers import guess_lexer + +logger = structlog.get_logger("codegate") + +CODE_BLOCK_PATTERN = re.compile( + r"```" # Opening backticks, no whitespace after backticks and before language + r"(?:(?P[a-zA-Z0-9_+-]+)\s+)?" # Language must be followed by whitespace if present + r"(?:(?P[^\s\(\n]+))?" # Optional filename (cannot contain spaces or parentheses) + r"(?:\s+\([0-9]+-[0-9]+\))?" # Optional line numbers in parentheses + r"\s*\n" # Required newline after metadata + r"(?P.*?)" # Content (non-greedy match) + r"```", # Closing backticks + re.DOTALL, +) + +CODE_BLOCK_WITH_FILENAME_PATTERN = re.compile( + r"```" # Opening backticks, no whitespace after backticks and before language + r"(?:(?P[a-zA-Z0-9_+-]+)\s+)?" # Language must be followed by whitespace if present + r"(?P[^\s\(\n]+)" # Mandatory filename (cannot contain spaces or parentheses) + r"(?:\s+\([0-9]+-[0-9]+\))?" # Optional line numbers in parentheses + r"\s*\n" # Required newline after metadata + r"(?P.*?)" # Content (non-greedy match) + r"```", # Closing backticks + re.DOTALL, +) + +CLINE_FILE_CONTENT_PATTERN = re.compile( + r"[^\"]+)\">" # Match the opening tag with mandatory file + r"(?P.*?)" # Match the content (non-greedy) + r"", # Match the closing tag + re.DOTALL, +) + +AIDER_SUMMARIES_CONTENT_PATTERN = re.compile( + r"^(?P[^\n]+):\n" # Match the filepath as the header + r"(?P.*?)" # Match the content (non-greedy) + r"⋮...\n\n", # Match the ending pattern with dots + re.DOTALL | re.MULTILINE, +) + +AIDER_FILE_CONTENT_PATTERN = re.compile( + r"^(?P[^\n]+)\n" # Match the filepath as the header + r"```" # Match the opening triple backticks + r"(?P.*?)" # Match the content (non-greedy) + r"```", # Match the closing triple backticks + re.DOTALL | re.MULTILINE, +) + +OPEN_INTERPRETER_CONTENT_PATTERN = re.compile( + r"# Attempting to read the content of `(?P[^`]+)`" # Match the filename backticks + r".*?" # Match any characters non-greedily + r"File read successfully\.\n" # Match the "File read successfully." text + r"'(?P.*?)'", # Match the content wrapped in single quotes + re.DOTALL, +) + +OPEN_INTERPRETER_Y_CONTENT_PATTERN = re.compile( + r"# Open and read the contents of the (?P[^\s]+) file" # Match the filename + r".*?" # Match any characters non-greedily + r"\n\n" # Match the double line break + r"(?P.*)", # Match everything that comes after the double line break + re.DOTALL, +) + + +class MatchedPatternSnippet(BaseModel): + """ + Represents a match from the code snippet patterns. + Meant to be used by all CodeSnippetExtractors. + """ + + language: Optional[str] + filename: Optional[str] + content: str + + +class CodeSnippet(BaseModel): + """ + Represents a code snippet with its programming language. + + Args: + language: The programming language identifier (e.g., 'python', 'javascript') + code: The actual code content + """ + + code: str + language: Optional[str] + filepath: Optional[str] + libraries: List[str] = [] + file_extension: Optional[str] = None + + @field_validator("language", mode="after") + @classmethod + def ensure_lowercase(cls, value: str) -> str: + if value is not None: + value = value.strip().lower() + return value + + @model_validator(mode="after") + def fill_file_extension(self) -> Self: + if self.filepath is not None: + self.file_extension = Path(self.filepath).suffix + return self + + +class CodeSnippetExtractor(ABC): + + def __init__(self): + self._extension_mapping = { + ".py": "python", + ".js": "javascript", + ".ts": "typescript", + ".tsx": "typescript", + ".go": "go", + ".rs": "rust", + ".java": "java", + } + self._language_mapping = { + "py": "python", + "js": "javascript", + "ts": "typescript", + "tsx": "typescript", + "go": "go", + "rs": "rust", + "java": "java", + } + self._available_languages = ["python", "javascript", "typescript", "go", "rust", "java"] + + @property + @abstractmethod + def codeblock_pattern(self) -> List[re.Pattern]: + """ + List of regex patterns to match code blocks without filenames. + """ + pass + + @property + @abstractmethod + def codeblock_with_filename_pattern(self) -> List[re.Pattern]: + """ + List of regex patterns to match code blocks with filenames. + """ + pass + + @abstractmethod + def _get_match_pattern_snippet(self, match: re.Match) -> MatchedPatternSnippet: + pass + + def _choose_regex(self, require_filepath: bool) -> List[re.Pattern]: + if require_filepath: + return self.codeblock_with_filename_pattern + else: + return self.codeblock_pattern + + def _ecosystem_from_filepath(self, filepath: str): + """ + Determine language from filepath. + + Args: + filepath: Path to the file + + Returns: + Determined language based on file extension + """ + + # Get the file extension + path_filename = Path(filepath) + file_extension = path_filename.suffix.lower() + return self._extension_mapping.get(file_extension, None) + + def _ecosystem_from_message(self, message: str): + """ + Determine language from message. + + Args: + message: The language from the message. Some extensions send a different + format where the language is present in the snippet, + e.g. "py /path/to/file (lineFrom-lineTo)" + + Returns: + Determined language based on message content + """ + return self._language_mapping.get(message, None) + + def _get_snippet_for_match(self, match: re.Match) -> CodeSnippet: + matched_snippet = self._get_match_pattern_snippet(match) + + # If we have a single word without extension after the backticks, + # it's a language identifier, not a filename. Typicaly used in the + # format ` ```python ` in output snippets + if ( + matched_snippet.filename + and not matched_snippet.language + and "." not in matched_snippet.filename + ): + lang = matched_snippet.filename + if lang not in self._available_languages: + # try to get it from the extension + lang = self._ecosystem_from_message(matched_snippet.filename) + if lang not in self._available_languages: + lang = None + matched_snippet.filename = None + else: + # Determine language from the message, either by the short + # language identifier or by the filename + lang = None + if matched_snippet.language: + lang = self._ecosystem_from_message(matched_snippet.language.strip()) + if lang is None and matched_snippet.filename: + matched_snippet.filename = matched_snippet.filename.strip() + # Determine language from the filename + lang = self._ecosystem_from_filepath(matched_snippet.filename) + if lang is None: + # try to guess it from the code + lexer = guess_lexer(matched_snippet.content) + if lexer and lexer.name: + lang = lexer.name.lower() + # only add available languages + if lang not in self._available_languages: + lang = None + + # just correct the typescript exception + lang_map = {"typescript": "javascript"} + if lang: + lang = lang_map.get(lang, lang) + return CodeSnippet( + filepath=matched_snippet.filename, code=matched_snippet.content, language=lang + ) + + def extract_snippets(self, message: str, require_filepath: bool = False) -> List[CodeSnippet]: + """ + Extract code snippets from a message. + + Args: + message: Input text containing code snippets + + Returns: + List of extracted code snippets + """ + regexes = self._choose_regex(require_filepath) + # Find all code block matches + return [ + self._get_snippet_for_match(match) + for regex in regexes + for match in regex.finditer(message) + ] + + def extract_unique_snippets(self, message: str) -> Dict[str, CodeSnippet]: + """ + Extract unique filpaths from a message. Uses the filepath as key. + + Args: + message: Input text containing code snippets + + Returns: + Dictionary of unique code snippets with the filepath as key + """ + regexes = self._choose_regex(require_filepath=True) + unique_snippets: Dict[str, CodeSnippet] = {} + for regex in regexes: + for match in regex.finditer(message): + snippet = self._get_snippet_for_match(match) + filename = Path(snippet.filepath).name if snippet.filepath else None + if filename and filename not in unique_snippets: + unique_snippets[filename] = snippet + + return unique_snippets + + +class DefaultCodeSnippetExtractor(CodeSnippetExtractor): + + @property + def codeblock_pattern(self) -> re.Pattern: + return [CODE_BLOCK_PATTERN] + + @property + def codeblock_with_filename_pattern(self) -> re.Pattern: + return [CODE_BLOCK_WITH_FILENAME_PATTERN] + + def _get_match_pattern_snippet(self, match: re.Match) -> MatchedPatternSnippet: + matched_language = match.group("language") if match.group("language") else None + filename = match.group("filename") if match.group("filename") else None + content = match.group("content") + return MatchedPatternSnippet(language=matched_language, filename=filename, content=content) + + +class ClineCodeSnippetExtractor(CodeSnippetExtractor): + + @property + def codeblock_pattern(self) -> re.Pattern: + return [CLINE_FILE_CONTENT_PATTERN] + + @property + def codeblock_with_filename_pattern(self) -> re.Pattern: + return [CLINE_FILE_CONTENT_PATTERN] + + def _get_match_pattern_snippet(self, match: re.Match) -> MatchedPatternSnippet: + # We don't have language in the cline pattern + matched_language = None + filename = match.group("filename") + content = match.group("content") + return MatchedPatternSnippet(language=matched_language, filename=filename, content=content) + + +class AiderCodeSnippetExtractor(CodeSnippetExtractor): + + @property + def codeblock_pattern(self) -> re.Pattern: + return [AIDER_SUMMARIES_CONTENT_PATTERN, AIDER_FILE_CONTENT_PATTERN] + + @property + def codeblock_with_filename_pattern(self) -> re.Pattern: + return [AIDER_SUMMARIES_CONTENT_PATTERN, AIDER_FILE_CONTENT_PATTERN] + + def _get_match_pattern_snippet(self, match: re.Match) -> MatchedPatternSnippet: + # We don't have language in the cline pattern + matched_language = None + filename = match.group("filename") + content = match.group("content") + return MatchedPatternSnippet(language=matched_language, filename=filename, content=content) + + +class OpenInterpreterCodeSnippetExtractor(CodeSnippetExtractor): + + @property + def codeblock_pattern(self) -> re.Pattern: + return [OPEN_INTERPRETER_CONTENT_PATTERN, OPEN_INTERPRETER_Y_CONTENT_PATTERN] + + @property + def codeblock_with_filename_pattern(self) -> re.Pattern: + return [OPEN_INTERPRETER_CONTENT_PATTERN, OPEN_INTERPRETER_Y_CONTENT_PATTERN] + + def _get_match_pattern_snippet(self, match: re.Match) -> MatchedPatternSnippet: + # We don't have language in the cline pattern + matched_language = None + filename = match.group("filename") + content = match.group("content") + return MatchedPatternSnippet(language=matched_language, filename=filename, content=content) diff --git a/src/codegate/muxing/router.py b/src/codegate/muxing/router.py index 812a6c61..b75a81bf 100644 --- a/src/codegate/muxing/router.py +++ b/src/codegate/muxing/router.py @@ -3,7 +3,9 @@ import structlog from fastapi import APIRouter, HTTPException, Request +from codegate.clients.clients import ClientType from codegate.clients.detector import DetectClient +from codegate.extract_snippets.factory import CodeSnippetExtractorFactory from codegate.muxing import rulematcher from codegate.muxing.adapter import BodyAdapter, ResponseAdapter from codegate.providers.registry import ProviderRegistry @@ -36,6 +38,37 @@ def get_routes(self) -> APIRouter: def _ensure_path_starts_with_slash(self, path: str) -> str: return path if path.startswith("/") else f"/{path}" + def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> set[str]: + """ + Extract filenames from the request data. + """ + body_extractor = CodeSnippetExtractorFactory.create_snippet_extractor(detected_client) + return body_extractor.extract_unique_snippets(data) + + async def _get_model_routes(self, filenames: set[str]) -> list[rulematcher.ModelRoute]: + """ + Get the model routes for the given filenames. + """ + model_routes = [] + mux_registry = await rulematcher.get_muxing_rules_registry() + try: + # Try to get a catch_all route + single_model_route = await mux_registry.get_match_for_active_workspace( + thing_to_match=None + ) + model_routes.append(single_model_route) + + # Get the model routes for each filename + for filename in filenames: + model_route = await mux_registry.get_match_for_active_workspace( + thing_to_match=filename + ) + model_routes.append(model_route) + except Exception as e: + logger.error(f"Error getting active workspace muxes: {e}") + raise HTTPException(str(e), status_code=404) + return model_routes + def _setup_routes(self): @self.router.post(f"/{self.route_name}/{{rest_of_path:path}}") @@ -56,18 +89,17 @@ async def route_to_dest_provider( body = await request.body() data = json.loads(body) - mux_registry = await rulematcher.get_muxing_rules_registry() - try: - # TODO: For future releases we will have to idenify a thing_to_match - # and use our registry to get the correct muxes for the active workspace - model_route = await mux_registry.get_match_for_active_workspace(thing_to_match=None) - except Exception as e: - logger.error(f"Error getting active workspace muxes: {e}") - raise HTTPException(str(e)) + filenames_in_data = self._extract_request_filenames(request.state.detected_client, data) + logger.info(f"Extracted filenames from request: {filenames_in_data}") - if not model_route: + model_routes = await self._get_model_routes(filenames_in_data) + if not model_routes: raise HTTPException("No rule found for the active workspace", status_code=404) + # We still need some logic here to handle the case where we have multiple model routes. + # For the moment since we match all only pick the first. + model_route = model_routes[0] + # Parse the input data and map it to the destination provider format rest_of_path = self._ensure_path_starts_with_slash(rest_of_path) new_data = self._body_adapter.map_body_to_dest(model_route, data) diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index bb77c233..5b8b87eb 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -13,31 +13,12 @@ from codegate.clients.clients import ClientType from codegate.db.models import Alert, Output, Prompt +from codegate.extract_snippets.message_extractor import CodeSnippet from codegate.pipeline.secrets.manager import SecretsManager logger = structlog.get_logger("codegate") -@dataclass -class CodeSnippet: - """ - Represents a code snippet with its programming language. - - Args: - language: The programming language identifier (e.g., 'python', 'javascript') - code: The actual code content - """ - - code: str - language: Optional[str] - filepath: Optional[str] - libraries: List[str] = field(default_factory=list) - - def __post_init__(self): - if self.language is not None: - self.language = self.language.strip().lower() - - class AlertSeverity(Enum): INFO = "info" CRITICAL = "critical" diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py index 159b0a92..15a5b122 100644 --- a/src/codegate/pipeline/codegate_context_retriever/codegate.py +++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py @@ -5,13 +5,13 @@ from litellm import ChatCompletionRequest from codegate.clients.clients import ClientType +from codegate.extract_snippets.message_extractor import DefaultCodeSnippetExtractor from codegate.pipeline.base import ( AlertSeverity, PipelineContext, PipelineResult, PipelineStep, ) -from codegate.pipeline.extract_snippets.extract_snippets import extract_snippets from codegate.storage.storage_engine import StorageEngine from codegate.utils.package_extractor import PackageExtractor from codegate.utils.utils import generate_vector_string @@ -25,6 +25,9 @@ class CodegateContextRetriever(PipelineStep): the word "codegate" in the user message. """ + def __init__(self): + self.extractor = DefaultCodeSnippetExtractor() + @property def name(self) -> str: """ @@ -70,7 +73,7 @@ async def process( # noqa: C901 storage_engine = StorageEngine() # Extract any code snippets - snippets = extract_snippets(user_message) + snippets = self.extractor.extract_snippets(user_message) bad_snippet_packages = [] if len(snippets) > 0: diff --git a/src/codegate/pipeline/comment/__init__.py b/src/codegate/pipeline/comment/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/codegate/pipeline/extract_snippets/output.py b/src/codegate/pipeline/comment/output.py similarity index 94% rename from src/codegate/pipeline/extract_snippets/output.py rename to src/codegate/pipeline/comment/output.py index 53b3af24..6210ee13 100644 --- a/src/codegate/pipeline/extract_snippets/output.py +++ b/src/codegate/pipeline/comment/output.py @@ -5,8 +5,11 @@ from litellm import ModelResponse from litellm.types.utils import Delta, StreamingChoices -from codegate.pipeline.base import AlertSeverity, CodeSnippet, PipelineContext -from codegate.pipeline.extract_snippets.extract_snippets import extract_snippets +from codegate.extract_snippets.message_extractor import ( + CodeSnippet, + DefaultCodeSnippetExtractor, +) +from codegate.pipeline.base import AlertSeverity, PipelineContext from codegate.pipeline.output import OutputPipelineContext, OutputPipelineStep from codegate.storage import StorageEngine from codegate.utils.package_extractor import PackageExtractor @@ -17,6 +20,9 @@ class CodeCommentStep(OutputPipelineStep): """Pipeline step that adds comments after code blocks""" + def __init__(self): + self.extractor = DefaultCodeSnippetExtractor() + @property def name(self) -> str: return "code-comment" @@ -118,7 +124,7 @@ async def process_chunk( current_content = "".join(context.processed_content + [chunk.choices[0].delta.content]) # Extract snippets from current content - snippets = extract_snippets(current_content) + snippets = self.extractor.extract_snippets(current_content) # Check if a new snippet has been completed if len(snippets) > len(context.snippets): diff --git a/src/codegate/pipeline/extract_snippets/extract_snippets.py b/src/codegate/pipeline/extract_snippets/extract_snippets.py deleted file mode 100644 index 5d95abc9..00000000 --- a/src/codegate/pipeline/extract_snippets/extract_snippets.py +++ /dev/null @@ -1,131 +0,0 @@ -import os -import re -from typing import List, Optional - -import structlog -from pygments.lexers import guess_lexer - -from codegate.pipeline.base import CodeSnippet - -CODE_BLOCK_PATTERN = re.compile( - r"```" # Opening backticks, no whitespace after backticks and before language - r"(?:(?P[a-zA-Z0-9_+-]+)\s+)?" # Language must be followed by whitespace if present - r"(?:(?P[^\s\(\n]+))?" # Optional filename (cannot contain spaces or parentheses) - r"(?:\s+\([0-9]+-[0-9]+\))?" # Optional line numbers in parentheses - r"\s*\n" # Required newline after metadata - r"(?P.*?)" # Content (non-greedy match) - r"```", # Closing backticks - re.DOTALL, -) - -logger = structlog.get_logger("codegate") - - -def ecosystem_from_filepath(filepath: str) -> Optional[str]: - """ - Determine language from filepath. - - Args: - filepath: Path to the file - - Returns: - Determined language based on file extension - """ - # Implement file extension to language mapping - extension_mapping = { - ".py": "python", - ".js": "javascript", - ".ts": "typescript", - ".tsx": "typescript", - ".go": "go", - ".rs": "rust", - ".java": "java", - } - - # Get the file extension - ext = os.path.splitext(filepath)[1].lower() - return extension_mapping.get(ext, None) - - -def ecosystem_from_message(message: str) -> Optional[str]: - """ - Determine language from message. - - Args: - message: The language from the message. Some extensions send a different - format where the language is present in the snippet, - e.g. "py /path/to/file (lineFrom-lineTo)" - - Returns: - Determined language based on message content - """ - language_mapping = { - "py": "python", - "js": "javascript", - "ts": "typescript", - "tsx": "typescript", - "go": "go", - "rs": "rust", - "java": "java", - } - return language_mapping.get(message, None) - - -def extract_snippets(message: str) -> List[CodeSnippet]: - """ - Extract code snippets from a message. - - Args: - message: Input text containing code snippets - - Returns: - List of extracted code snippets - """ - # Regular expression to find code blocks - - snippets: List[CodeSnippet] = [] - available_languages = ["python", "javascript", "typescript", "go", "rust", "java"] - - # Find all code block matches - for match in CODE_BLOCK_PATTERN.finditer(message): - matched_language = match.group("language") if match.group("language") else None - filename = match.group("filename") if match.group("filename") else None - content = match.group("content") - - # If we have a single word without extension after the backticks, - # it's a language identifier, not a filename. Typicaly used in the - # format ` ```python ` in output snippets - if filename and not matched_language and "." not in filename: - lang = filename - if lang not in available_languages: - # try to get it from the extension - lang = ecosystem_from_message(filename) - if lang not in available_languages: - lang = None - filename = None - else: - # Determine language from the message, either by the short - # language identifier or by the filename - lang = None - if matched_language: - lang = ecosystem_from_message(matched_language.strip()) - if lang is None and filename: - filename = filename.strip() - # Determine language from the filename - lang = ecosystem_from_filepath(filename) - if lang is None: - # try to guess it from the code - lexer = guess_lexer(content) - if lexer and lexer.name: - lang = lexer.name.lower() - # only add available languages - if lang not in available_languages: - lang = None - - #  just correct the typescript exception - lang_map = {"typescript": "javascript"} - if lang: - lang = lang_map.get(lang, lang) - snippets.append(CodeSnippet(filepath=filename, code=content, language=lang)) - - return snippets diff --git a/src/codegate/pipeline/factory.py b/src/codegate/pipeline/factory.py index e51c7691..0fdd66c4 100644 --- a/src/codegate/pipeline/factory.py +++ b/src/codegate/pipeline/factory.py @@ -5,7 +5,7 @@ from codegate.pipeline.base import PipelineStep, SequentialPipelineProcessor from codegate.pipeline.cli.cli import CodegateCli from codegate.pipeline.codegate_context_retriever.codegate import CodegateContextRetriever -from codegate.pipeline.extract_snippets.output import CodeCommentStep +from codegate.pipeline.comment.output import CodeCommentStep from codegate.pipeline.output import OutputPipelineProcessor, OutputPipelineStep from codegate.pipeline.secrets.manager import SecretsManager from codegate.pipeline.secrets.secrets import ( diff --git a/src/codegate/pipeline/output.py b/src/codegate/pipeline/output.py index 76895120..6f990e03 100644 --- a/src/codegate/pipeline/output.py +++ b/src/codegate/pipeline/output.py @@ -8,7 +8,8 @@ from litellm.types.utils import Delta, StreamingChoices from codegate.db.connection import DbRecorder -from codegate.pipeline.base import CodeSnippet, PipelineContext +from codegate.extract_snippets.message_extractor import CodeSnippet +from codegate.pipeline.base import PipelineContext logger = structlog.get_logger("codegate") diff --git a/tests/extract_snippets/test_body_extractor.py b/tests/extract_snippets/test_body_extractor.py new file mode 100644 index 00000000..ab4aaf4e --- /dev/null +++ b/tests/extract_snippets/test_body_extractor.py @@ -0,0 +1,159 @@ +from typing import Dict, List, NamedTuple + +import pytest + +from codegate.extract_snippets.body_extractor import ( + ClineBodySnippetExtractor, + OpenInterpreterBodySnippetExtractor, +) + + +class BodyCodeSnippetTest(NamedTuple): + input_body_dict: Dict[str, List[Dict[str, str]]] + expected_count: int + expected: List[str] + + +def _evaluate_actual_filenames(filenames: set[str], test_case: BodyCodeSnippetTest): + assert len(filenames) == test_case.expected_count + assert filenames == set(test_case.expected) + + +@pytest.mark.parametrize( + "test_case", + [ + # Analyze processed snippets from OpenInterpreter + BodyCodeSnippetTest( + input_body_dict={ + "messages": [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "toolu_4", + "type": "function", + "function": { + "name": "execute", + "arguments": ( + '{"language": "python", "code": "\\n' + "# Open and read the contents of the src/codegate/api/v1.py" + " file\\n" + "with open('src/codegate/api/v1.py', 'r') as file:\\n " + 'content = file.read()\\n\\ncontent\\n"}' + ), + }, + } + ], + }, + { + "role": "tool", + "name": "execute", + "content": ( + "Output truncated.\n\nr as e:\\n " + 'raise HTTPException(status_code=400",' + ), + "tool_call_id": "toolu_4", + }, + ] + }, + expected_count=1, + expected=["v1.py"], + ), + ], +) +def test_body_extract_openinterpreter_snippets(test_case: BodyCodeSnippetTest): + extractor = OpenInterpreterBodySnippetExtractor() + filenames = extractor.extract_unique_snippets(test_case.input_body_dict) + _evaluate_actual_filenames(filenames, test_case) + + +@pytest.mark.parametrize( + "test_case", + [ + # Analyze processed snippets from OpenInterpreter + BodyCodeSnippetTest( + input_body_dict={ + "messages": [ + { + "role": "user", + "content": ''' + [ +now please analyze the folder 'codegate/src/codegate/api/' (see below for folder content) + + + +├── __init__.py +├── __pycache__/ +├── v1.py +├── v1_models.py +└── v1_processing.py + + + + + + +from typing import List, Optional +from uuid import UUID + +import requests +import structlog + +v1 = APIRouter() +wscrud = crud.WorkspaceCrud() +pcrud = provendcrud.ProviderCrud() + + + + +import datetime +from enum import Enum + + +class Conversation(pydantic.BaseModel): + """ + Represents a conversation. + """ + + question_answers: List[QuestionAnswer] + provider: Optional[str] + type: QuestionType + chat_id: str + conversation_timestamp: datetime.datetime + token_usage_agg: Optional[TokenUsageAggregate] + + + + +import asyncio +import json +import re +from collections import defaultdict + +async def _process_prompt_output_to_partial_qa( + prompts_outputs: List[GetPromptWithOutputsRow], +) -> List[PartialQuestionAnswer]: + """ + Process the prompts and outputs to PartialQuestionAnswer objects. + """ + # Parse the prompts and outputs in parallel + async with asyncio.TaskGroup() as tg: + tasks = [tg.create_task(_get_partial_question_answer(row)) for row in prompts_outputs] + return [task.result() for task in tasks if task.result() is not None] + + + + ''', + }, + ] + }, + expected_count=4, + expected=["__init__.py", "v1.py", "v1_models.py", "v1_processing.py"], + ), + ], +) +def test_body_extract_cline_snippets(test_case: BodyCodeSnippetTest): + extractor = ClineBodySnippetExtractor() + filenames = extractor.extract_unique_snippets(test_case.input_body_dict) + _evaluate_actual_filenames(filenames, test_case) diff --git a/tests/extract_snippets/test_message_extractor.py b/tests/extract_snippets/test_message_extractor.py new file mode 100644 index 00000000..07e4d8b3 --- /dev/null +++ b/tests/extract_snippets/test_message_extractor.py @@ -0,0 +1,804 @@ +from typing import List, NamedTuple + +import pytest + +from codegate.extract_snippets.message_extractor import ( + AiderCodeSnippetExtractor, + ClineCodeSnippetExtractor, + CodeSnippet, + DefaultCodeSnippetExtractor, + OpenInterpreterCodeSnippetExtractor, +) + + +class CodeSnippetTest(NamedTuple): + input_message: str + expected_count: int + expected: List[CodeSnippet] + + +def _evaluate_actual_snippets(actual_snips: List[CodeSnippet], expected_snips: CodeSnippetTest): + assert len(actual_snips) == expected_snips.expected_count + + for expected, actual in zip(expected_snips.expected, actual_snips): + assert actual.language == expected.language + assert actual.filepath == expected.filepath + assert actual.file_extension == expected.file_extension + assert expected.code in actual.code + + +@pytest.mark.parametrize( + "test_case", + [ + # Single Python code block without filename + CodeSnippetTest( + input_message=""": + Here's a Python snippet: + ``` + def hello(): + print("Hello, world!") + ``` + """, + expected_count=1, + expected=[ + CodeSnippet( + language=None, filepath=None, code='print("Hello, world!")', file_extension=None + ), + ], + ), + # output code snippet with no filename + CodeSnippetTest( + input_message=""": + ```python + @app.route('/') + def hello(): + GITHUB_TOKEN="ghp_RjzIRljYij9CznoS7QAnD5RaFF6yH32073uI" + if __name__ == '__main__': + app.run() + return "Hello, Moon!" + ``` + """, + expected_count=1, + expected=[ + CodeSnippet( + language="python", + filepath=None, + code="Hello, Moon!", + file_extension=None, + ), + ], + ), + # Single Python code block + CodeSnippetTest( + input_message=""" + Here's a Python snippet: + ```hello_world.py (8-13) + def hello(): + print("Hello, world!") + ``` + """, + expected_count=1, + expected=[ + CodeSnippet( + language="python", + filepath="hello_world.py", + code='print("Hello, world!")', + file_extension=".py", + ), + ], + ), + # Single Python code block with a language identifier + CodeSnippetTest( + input_message=""" + Here's a Python snippet: + ```py goodbye_world.py (8-13) + def hello(): + print("Goodbye, world!") + ``` + """, + expected_count=1, + expected=[ + CodeSnippet( + language="python", + filepath="goodbye_world.py", + code='print("Goodbye, world!")', + file_extension=".py", + ), + ], + ), + # Multiple code blocks with different languages + CodeSnippetTest( + input_message=""" + Python snippet: + ```main.py + def hello(): + print("Hello") + ``` + + JavaScript snippet: + ```script.js (1-3) + function greet() { + console.log("Hi"); + } + ``` + """, + expected_count=2, + expected=[ + CodeSnippet(language="python", filepath="main.py", code='print("Hello")'), + CodeSnippet( + language="javascript", + filepath="script.js", + code='console.log("Hi");', + file_extension=".js", + ), + ], + ), + # No code blocks + CodeSnippetTest( + input_message="Just a plain text message", + expected_count=0, + expected=[], + ), + # unknown language + CodeSnippetTest( + input_message=""": + Here's a Perl snippet: + ```hello_world.pl + I'm a Perl script + ``` + """, + expected_count=1, + expected=[ + CodeSnippet( + language=None, + filepath="hello_world.pl", + code="I'm a Perl script", + file_extension=".pl", + ), + ], + ), + ], +) +def test_extract_snippets(test_case: CodeSnippetTest): + extractor = DefaultCodeSnippetExtractor() + snippets = extractor.extract_snippets(test_case.input_message) + _evaluate_actual_snippets(snippets, test_case) + + +@pytest.mark.parametrize( + "test_case", + [ + # Single snippet from Continue + CodeSnippetTest( + input_message=""" + + + ```py testing_file.py (1-17) + import invokehttp + import fastapi + from fastapi import FastAPI, Request, Response, HTTPException + import numpy + + GITHUB_TOKEN="ghp_1J9Z3Z2dfg4dfs23dsfsdf232aadfasdfasfasdf32" + + def add(a, b): + return a + b + + def multiply(a, b): + return a * b + + + + def substract(a, b): + + ``` + analyze this file + """, + expected_count=1, + expected=[ + CodeSnippet( + language="python", + filepath="testing_file.py", + code="def multiply(a, b):", + file_extension=".py", + ), + ], + ), + # Two snippets from Continue, one inserting with CTRL+L and another one with @ + CodeSnippetTest( + input_message=''' +```/Users/user/StacklokRepos/codegate/tests/pipeline/extract_snippets/test_extract_snippets.py +from typing import List, NamedTuple + +import pytest + +from codegate.pipeline.extract_snippets.extract_snippets import CodeSnippet, CodeSnippetExtractor + + +class CodeSnippetTest(NamedTuple): + input_message: str + expected_count: int + expected: List[CodeSnippet] + +@pytest.mark.parametrize( + "filepath", + [ + # No extension + "README", + "script", + "README.txt", + # Unknown extensions + "file.xyz", + "unknown.extension", + ], +) +def test_no_or_unknown_extensions(filepath): + extractor = CodeSnippetExtractor() + assert extractor._ecosystem_from_filepath(filepath) is None + +``` + + + +```py codegate/src/codegate/pipeline/extract_snippets/extract_snippets.py (24-50) +class CodeSnippet(BaseModel): + """ + Represents a code snippet with its programming language. + + Args: + language: The programming language identifier (e.g., 'python', 'javascript') + code: The actual code content + """ + + code: str + language: Optional[str] + filepath: Optional[str] + libraries: List[str] = [] + file_extension: Optional[str] = None +``` +analyze this file with respect to test_extract_snippets.py + ''', + expected_count=2, + expected=[ + CodeSnippet( + language="python", + filepath="/Users/user/StacklokRepos/codegate/tests/pipeline/extract_snippets/test_extract_snippets.py", + code="def test_no_or_unknown_extensions(filepath):", + file_extension=".py", + ), + CodeSnippet( + language="python", + filepath="codegate/src/codegate/pipeline/extract_snippets/extract_snippets.py", + code="class CodeSnippet(BaseModel)", + file_extension=".py", + ), + ], + ), + # Two snippets from Continue, one inserting with CTRL+L and another one with @ + CodeSnippetTest( + input_message=""" +```/Users/foo_user/StacklokRepos/src/README.MD +# Handling changes + +Changes are not immediate + +### Example + +On the response class "Package", changing "description" "repo_description": + +From: +``` +@dataclass +class Package: + id: str + name: str +``` + +To: +``` +@dataclass +class Package: + id: str + name: str +``` + +And finally + +``` +@dataclass +class Package: + id: str + name: str + type: str +``` +``` + +README.MD and that file? + """, + expected_count=1, + expected=[ + CodeSnippet( + language=None, + filepath="/Users/foo_user/StacklokRepos/src/README.MD", + code="Changes are not immediate", + file_extension=".md", + ), + ], + ), + ], +) +def test_extract_snippets_require_filepath(test_case: CodeSnippetTest): + extractor = DefaultCodeSnippetExtractor() + snippets = extractor.extract_snippets(test_case.input_message, require_filepath=True) + _evaluate_actual_snippets(snippets, test_case) + + +@pytest.mark.parametrize( + "test_case", + [ + # Analyze folder from Cline + CodeSnippetTest( + input_message=""" +[TASK RESUMPTION] This task was interrupted 1 day ago. +It may or may not be complete, so please reassess the task context. +Be aware that the project state may have changed since then. +The current working directory is now '/Users/aponcedeleonch/StacklokRepos'. +If the task has not been completed, retry the last step before interruption +and proceed with completing the task. + +Note: If you previously attempted a tool use that the user did not provide a result for, +you should assume the tool use was not successful and assess whether you should retry. +If the last tool was a browser_action, the browser has been closed and you must launch a new +browser if needed. + +New instructions for task continuation: + +please evaluate my folder 'codegate/tests/pipeline/extract_snippets/' (see below for folder content) + and suggest improvements + + + +├── __pycache__/ +└── test_extract_snippets.py + + +from typing import List, NamedTuple + +import pytest + +from codegate.pipeline.extract_snippets.extract_snippets import CodeSnippet, CodeSnippetExtractor + + +class CodeSnippetTest(NamedTuple): + input_message: str + expected_count: int + expected: List[CodeSnippet] + + +def _evaluate_actual_snippets(actual_snips: List[CodeSnippet], expected_snips: CodeSnippetTest): + assert len(actual_snips) == expected_snips.expected_count + + for expected, actual in zip(expected_snips.expected, actual_snips): + assert actual.language == expected.language + assert actual.filepath == expected.filepath + assert actual.file_extension == expected.file_extension + assert expected.code in actual.code + + + + """, + expected_count=1, + expected=[ + CodeSnippet( + language="python", + filepath="codegate/tests/pipeline/extract_snippets/test_extract_snippets.py", + code="def _evaluate_actual_snippets", + file_extension=".py", + ), + ], + ), + # Several snippets from Cline + CodeSnippetTest( + input_message=''' +[ +now please analyze the folder 'codegate/src/codegate/api/' (see below for folder content) + + + +├── __init__.py +├── __pycache__/ +├── v1.py +├── v1_models.py +└── v1_processing.py + + + + + + +from typing import List, Optional +from uuid import UUID + +import requests +import structlog + +v1 = APIRouter() +wscrud = crud.WorkspaceCrud() +pcrud = provendcrud.ProviderCrud() + + + + +import datetime +from enum import Enum + + +class Conversation(pydantic.BaseModel): + """ + Represents a conversation. + """ + + question_answers: List[QuestionAnswer] + provider: Optional[str] + type: QuestionType + chat_id: str + conversation_timestamp: datetime.datetime + token_usage_agg: Optional[TokenUsageAggregate] + + + + +import asyncio +import json +import re +from collections import defaultdict + +async def _process_prompt_output_to_partial_qa( + prompts_outputs: List[GetPromptWithOutputsRow], +) -> List[PartialQuestionAnswer]: + """ + Process the prompts and outputs to PartialQuestionAnswer objects. + """ + # Parse the prompts and outputs in parallel + async with asyncio.TaskGroup() as tg: + tasks = [tg.create_task(_get_partial_question_answer(row)) for row in prompts_outputs] + return [task.result() for task in tasks if task.result() is not None] + + + + ''', + expected_count=4, + expected=[ + CodeSnippet( + language="python", + filepath="codegate/src/codegate/api/__init__.py", + code="", + file_extension=".py", + ), + CodeSnippet( + language="python", + filepath="codegate/src/codegate/api/v1.py", + code="v1 = APIRouter()", + file_extension=".py", + ), + CodeSnippet( + language="python", + filepath="codegate/src/codegate/api/v1_models.py", + code="class Conversation(pydantic.BaseModel):", + file_extension=".py", + ), + CodeSnippet( + language="python", + filepath="codegate/src/codegate/api/v1_processing.py", + code="async def _process_prompt_output_to_partial_qa(", + file_extension=".py", + ), + ], + ), + ], +) +def test_extract_cline_snippets(test_case: CodeSnippetTest): + extractor = ClineCodeSnippetExtractor() + snippets = extractor.extract_snippets(test_case.input_message, require_filepath=True) + _evaluate_actual_snippets(snippets, test_case) + + +@pytest.mark.parametrize( + "test_case", + [ + # Analyze summary extracts from Aider + CodeSnippetTest( + input_message=''' +Here are summaries of some files present in my git repository. +Do not propose changes to these files, treat them as *read-only*. +If you need to edit any of these files, ask me to *add them to the chat* first. + +src/codegate/codegate_logging.py: +⋮... +│def serialize_for_logging(obj: Any) -> Any: +⋮... + +src/codegate/config.py: +⋮... +│@dataclass +│class Config: +│ """Application configuration with priority resolution.""" +│ +⋮... +│ @classmethod +│ def from_file(cls, config_path: Union[str, Path]) -> "Config": +⋮... +│ @classmethod +│ def load( +│ cls, +│ config_path: Optional[Union[str, Path]] = None, +│ prompts_path: Optional[Union[str, Path]] = None, +│ cli_port: Optional[int] = None, +│ cli_proxy_port: Optional[int] = None, +│ cli_host: Optional[str] = None, +│ cli_log_level: Optional[str] = None, +│ cli_log_format: Optional[str] = None, +│ cli_provider_urls: Optional[Dict[str, str]] = None, +⋮... +│ @classmethod +│ def get_config(cls) -> "Config": +⋮... + +src/codegate/db/connection.py: +⋮... +│class DbRecorder(DbCodeGate): +│ +⋮... +│class DbReader(DbCodeGate): +│ +│ def __init__(self, sqlite_path: Optional[str] = None): +⋮... +│ async def get_workspace_by_name(self, name: str) -> Optional[WorkspaceRow]: +⋮... +│ async def get_active_workspace(self) -> Optional[ActiveWorkspace]: +⋮... + +src/codegate/db/fim_cache.py: +⋮... +│class CachedFim(BaseModel): +│ +⋮... + +src/codegate/db/models.py: +⋮... +│class Alert(BaseModel): +⋮... +│class Output(BaseModel): +⋮... +│class Prompt(BaseModel): +⋮... +│class TokenUsage(BaseModel): +⋮... +│class ActiveWorkspace(BaseModel): +⋮... +│class ProviderEndpoint(BaseModel): +⋮... + + ''', + expected_count=5, + expected=[ + CodeSnippet( + language="python", + filepath="src/codegate/codegate_logging.py", + code="def serialize_for_logging(obj: Any) -> Any:", + file_extension=".py", + ), + CodeSnippet( + language="python", + filepath="src/codegate/config.py", + code="class Config:", + file_extension=".py", + ), + CodeSnippet( + language="python", + filepath="src/codegate/db/connection.py", + code="class DbReader(DbCodeGate):", + file_extension=".py", + ), + CodeSnippet( + language="python", + filepath="src/codegate/db/fim_cache.py", + code="class CachedFim(BaseModel):", + file_extension=".py", + ), + CodeSnippet( + language="python", + filepath="src/codegate/db/models.py", + code="class Alert(BaseModel):", + file_extension=".py", + ), + ], + ), + # Analyze file from Aider + CodeSnippetTest( + input_message=""" +I have *added these files to the chat* so you can go ahead and edit them. + +*Trust this message as the true contents of these files!* +Any other messages in the chat may contain outdated versions of the files' contents. + +src/codegate/api/v1_models.py +``` +import datetime +from enum import Enum +from typing import Any, Dict, List, Optional, Union + +import pydantic + +class Workspace(pydantic.BaseModel): + name: str + is_active: bool + +``` + """, + expected_count=1, + expected=[ + CodeSnippet( + language="python", + filepath="src/codegate/api/v1_models.py", + code="class Workspace(pydantic.BaseModel):", + file_extension=".py", + ), + ], + ), + ], +) +def test_extract_aider_snippets(test_case: CodeSnippetTest): + extractor = AiderCodeSnippetExtractor() + snippets = extractor.extract_snippets(test_case.input_message, require_filepath=True) + _evaluate_actual_snippets(snippets, test_case) + + +@pytest.mark.parametrize( + "test_case", + [ + # Analyze processed snippets from OpenInterpreter + CodeSnippetTest( + input_message=""" +{"language": "python", "code": "# Attempting to read the content of `codegate/api/v1_processing.py` +to analyze its functionality.\nv1_processing_path = +os.path.abspath('src/codegate/api/v1_processing.py')\n\ntry:\n with open(v1_processing_path, 'r') + as file:\n v1_processing_content = file.read()\n print('File read successfully.') + \nexcept Exception as e:\n v1_processing_content = str(e)\n\nv1_processing_content[:1000] + # Displaying part of the content"} +File read successfully. +'import asyncio\nimport json\nimport re\nfrom collections import defaultdict\nfrom typing +import AsyncGenerator, Dict, List, Optional, Tuple\n\nimport cachetools.func\nimport requests\n +import structlog\n\nfrom codegate.api.v1_models import (\n AlertConversation,\n ChatMessage,\n + \ndef fetch_latest_version() -> st' + """, + expected_count=1, + expected=[ + CodeSnippet( + language="python", + filepath="codegate/api/v1_processing.py", + code="from codegate.api.v1_models import", + file_extension=".py", + ), + ], + ), + # Analyze processed snippets from OpenInterpreter when setting -y option + CodeSnippetTest( + input_message=""" +{"language": "python", "code": "\n# Open and read the contents of the src/codegate/api/v1.py file\n +with open('src/codegate/api/v1.py', 'r') as file:\n content = file.read()\n\ncontent\n"} +Output truncated. Showing the last 2800 characters. +You should try again and use computer.ai.summarize(output) over the output, or break it down into +smaller steps. Run `get_last_output()[0:2800]` to see the first page. + +r as e:\n raise HTTPException(status_code=400, detail=str(e))\n except Exception:\n +logger.exception("Error while setting muxes")\n +raise HTTPException(status_code=500, detail="Internal server error")\n\n +return Response(status_code=204)\n\n\n@v1.get("/alerts_notification", tags=["Dashboard"] + """, + expected_count=1, + expected=[ + CodeSnippet( + language="python", + filepath="src/codegate/api/v1.py", + code="raise HTTPException", + file_extension=".py", + ), + ], + ), + ], +) +def test_extract_openinterpreter_snippets(test_case: CodeSnippetTest): + extractor = OpenInterpreterCodeSnippetExtractor() + snippets = extractor.extract_snippets(test_case.input_message, require_filepath=True) + _evaluate_actual_snippets(snippets, test_case) + + +@pytest.mark.parametrize( + "filepath,expected", + [ + # Standard extensions + ("file.py", "python"), + ("script.js", "javascript"), + ("code.go", "go"), + ("app.ts", "typescript"), + ("component.tsx", "typescript"), + ("program.rs", "rust"), + ("App.java", "java"), + # Case insensitive + ("FILE.PY", "python"), + ("SCRIPT.JS", "javascript"), + # Full paths + ("/path/to/file.rs", "rust"), + ("C:\\Users\\name\\file.java", "java"), + ], +) +def test_valid_extensions(filepath, expected): + extractor = DefaultCodeSnippetExtractor() + assert extractor._ecosystem_from_filepath(filepath) == expected + + +@pytest.mark.parametrize( + "filepath", + [ + # No extension + "README", + "script", + "README.txt", + # Unknown extensions + "file.xyz", + "unknown.extension", + ], +) +def test_no_or_unknown_extensions(filepath): + extractor = DefaultCodeSnippetExtractor() + assert extractor._ecosystem_from_filepath(filepath) is None + + +@pytest.mark.parametrize( + "expected_filenames,message", + [ + # Single snippet + ( + ["main.py"], + """ + ```main.py + foo + ``` + """, + ), + # Repeated snippet + ( + ["main.py"], + """ + ```main.py + foo + ``` + + ```main.py + bar + ``` + """, + ), + # Multiple snippets + ( + ["main.py", "snippets.py"], + """ + ```main.py + foo + ``` + + ```src/codegate/snippets.py + bar + ``` + """, + ), + ], +) +def test_extract_unique_snippets(expected_filenames: List[str], message: str): + extractor = DefaultCodeSnippetExtractor() + snippets = extractor.extract_unique_snippets(message) + + actual_code_hashes = snippets.keys() + assert len(actual_code_hashes) == len(expected_filenames) + assert set(actual_code_hashes) == set(expected_filenames) diff --git a/tests/pipeline/extract_snippets/test_extract_snippets.py b/tests/pipeline/extract_snippets/test_extract_snippets.py deleted file mode 100644 index 4e878330..00000000 --- a/tests/pipeline/extract_snippets/test_extract_snippets.py +++ /dev/null @@ -1,186 +0,0 @@ -from typing import List, NamedTuple - -import pytest - -from codegate.pipeline.base import CodeSnippet -from codegate.pipeline.extract_snippets.extract_snippets import ( - ecosystem_from_filepath, - extract_snippets, -) - - -class CodeSnippetTest(NamedTuple): - input_message: str - expected_count: int - expected: List[CodeSnippet] - - -@pytest.mark.parametrize( - "test_case", - [ - # Single Python code block without filename - CodeSnippetTest( - input_message=""": - Here's a Python snippet: - ``` - def hello(): - print("Hello, world!") - ``` - """, - expected_count=1, - expected=[ - CodeSnippet(language=None, filepath=None, code='print("Hello, world!")'), - ], - ), - # output code snippet with no filename - CodeSnippetTest( - input_message=""": - ```python - @app.route('/') - def hello(): - GITHUB_TOKEN="ghp_RjzIRljYij9CznoS7QAnD5RaFF6yH32073uI" - if __name__ == '__main__': - app.run() - return "Hello, Moon!" - ``` - """, - expected_count=1, - expected=[ - CodeSnippet( - language="python", - filepath=None, - code="Hello, Moon!", - ), - ], - ), - # Single Python code block - CodeSnippetTest( - input_message=""" - Here's a Python snippet: - ```hello_world.py (8-13) - def hello(): - print("Hello, world!") - ``` - """, - expected_count=1, - expected=[ - CodeSnippet( - language="python", filepath="hello_world.py", code='print("Hello, world!")' - ), - ], - ), - # Single Python code block with a language identifier - CodeSnippetTest( - input_message=""" - Here's a Python snippet: - ```py goodbye_world.py (8-13) - def hello(): - print("Goodbye, world!") - ``` - """, - expected_count=1, - expected=[ - CodeSnippet( - language="python", filepath="goodbye_world.py", code='print("Goodbye, world!")' - ), - ], - ), - # Multiple code blocks with different languages - CodeSnippetTest( - input_message=""" - Python snippet: - ```main.py - def hello(): - print("Hello") - ``` - - JavaScript snippet: - ```script.js (1-3) - function greet() { - console.log("Hi"); - } - ``` - """, - expected_count=2, - expected=[ - CodeSnippet(language="python", filepath="main.py", code='print("Hello")'), - CodeSnippet( - language="javascript", - filepath="script.js", - code='console.log("Hi");', - ), - ], - ), - # No code blocks - CodeSnippetTest( - input_message="Just a plain text message", - expected_count=0, - expected=[], - ), - # unknown language - CodeSnippetTest( - input_message=""": - Here's a Perl snippet: - ```hello_world.pl - I'm a Perl script - ``` - """, - expected_count=1, - expected=[ - CodeSnippet( - language=None, - filepath="hello_world.pl", - code="I'm a Perl script", - ), - ], - ), - ], -) -def test_extract_snippets(test_case): - snippets = extract_snippets(test_case.input_message) - - assert len(snippets) == test_case.expected_count - - for expected, actual in zip(test_case.expected, snippets): - assert actual.language == expected.language - assert actual.filepath == expected.filepath - assert expected.code in actual.code - - -@pytest.mark.parametrize( - "filepath,expected", - [ - # Standard extensions - ("file.py", "python"), - ("script.js", "javascript"), - ("code.go", "go"), - ("app.ts", "typescript"), - ("component.tsx", "typescript"), - ("program.rs", "rust"), - ("App.java", "java"), - # Case insensitive - ("FILE.PY", "python"), - ("SCRIPT.JS", "javascript"), - # Full paths - ("/path/to/file.rs", "rust"), - ("C:\\Users\\name\\file.java", "java"), - ], -) -def test_valid_extensions(filepath, expected): - assert ecosystem_from_filepath(filepath) == expected - - -@pytest.mark.parametrize( - "filepath", - [ - # No extension - "README", - "script", - "README.txt", - # Unknown extensions - "file.xyz", - "unknown.extension", - ], -) -def test_no_or_unknown_extensions(filepath): - assert ecosystem_from_filepath(filepath) is None From ffea4b4d5f14f30bbe33c710fa64edecc1a82489 Mon Sep 17 00:00:00 2001 From: Alejandro Ponce Date: Thu, 6 Feb 2025 12:09:07 +0200 Subject: [PATCH 2/2] Attended comments in PR, thanks Jakub --- .../extract_snippets/body_extractor.py | 33 ++++++++++++++----- src/codegate/extract_snippets/factory.py | 22 ++++++++++++- src/codegate/muxing/router.py | 11 +++++-- .../codegate_context_retriever/codegate.py | 8 ++--- tests/extract_snippets/test_body_extractor.py | 4 +-- 5 files changed, 58 insertions(+), 20 deletions(-) diff --git a/src/codegate/extract_snippets/body_extractor.py b/src/codegate/extract_snippets/body_extractor.py index e885017a..3a2307f7 100644 --- a/src/codegate/extract_snippets/body_extractor.py +++ b/src/codegate/extract_snippets/body_extractor.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List +from typing import List, Optional from codegate.extract_snippets.message_extractor import ( AiderCodeSnippetExtractor, @@ -10,16 +10,28 @@ ) +class BodyCodeSnippetExtractorError(Exception): + pass + + class BodyCodeSnippetExtractor(ABC): def __init__(self): # Initialize the extractor in parent class. The child classes will set the extractor. - self._snippet_extractor: CodeSnippetExtractor = None + self._snippet_extractor: Optional[CodeSnippetExtractor] = None def _extract_from_user_messages(self, data: dict) -> set[str]: - copied_data = data.copy() + """ + The method extracts the code snippets from the user messages in the data got from the + clients. + + It returns a set of filenames extracted from the code snippets. + """ + if self._snippet_extractor is None: + raise BodyCodeSnippetExtractorError("Code Extractor not set.") + filenames: List[str] = [] - for msg in copied_data.get("messages", []): + for msg in data.get("messages", []): if msg.get("role", "") == "user": extracted_snippets = self._snippet_extractor.extract_unique_snippets( msg.get("content") @@ -28,7 +40,10 @@ def _extract_from_user_messages(self, data: dict) -> set[str]: return set(filenames) @abstractmethod - def extract_unique_snippets(self, data: dict) -> set[str]: + def extract_unique_filenames(self, data: dict) -> set[str]: + """ + Extract the unique filenames from the data received by the clients (Cline, Continue, ...) + """ pass @@ -37,7 +52,7 @@ class ContinueBodySnippetExtractor(BodyCodeSnippetExtractor): def __init__(self): self._snippet_extractor = DefaultCodeSnippetExtractor() - def extract_unique_snippets(self, data: dict) -> set[str]: + def extract_unique_filenames(self, data: dict) -> set[str]: return self._extract_from_user_messages(data) @@ -46,7 +61,7 @@ class AiderBodySnippetExtractor(BodyCodeSnippetExtractor): def __init__(self): self._snippet_extractor = AiderCodeSnippetExtractor() - def extract_unique_snippets(self, data: dict) -> set[str]: + def extract_unique_filenames(self, data: dict) -> set[str]: return self._extract_from_user_messages(data) @@ -55,7 +70,7 @@ class ClineBodySnippetExtractor(BodyCodeSnippetExtractor): def __init__(self): self._snippet_extractor = ClineCodeSnippetExtractor() - def extract_unique_snippets(self, data: dict) -> set[str]: + def extract_unique_filenames(self, data: dict) -> set[str]: return self._extract_from_user_messages(data) @@ -85,7 +100,7 @@ def _extract_result_from_tool_result(self, msg: dict) -> str: """ return msg.get("content", "") - def extract_unique_snippets(self, data: dict) -> set[str]: + def extract_unique_filenames(self, data: dict) -> set[str]: messages = data.get("messages", []) if not messages: return set() diff --git a/src/codegate/extract_snippets/factory.py b/src/codegate/extract_snippets/factory.py index 3c14a4a9..5f5f0231 100644 --- a/src/codegate/extract_snippets/factory.py +++ b/src/codegate/extract_snippets/factory.py @@ -6,9 +6,16 @@ ContinueBodySnippetExtractor, OpenInterpreterBodySnippetExtractor, ) +from codegate.extract_snippets.message_extractor import ( + AiderCodeSnippetExtractor, + ClineCodeSnippetExtractor, + CodeSnippetExtractor, + DefaultCodeSnippetExtractor, + OpenInterpreterCodeSnippetExtractor, +) -class CodeSnippetExtractorFactory: +class BodyCodeExtractorFactory: @staticmethod def create_snippet_extractor(detected_client: ClientType) -> BodyCodeSnippetExtractor: @@ -19,3 +26,16 @@ def create_snippet_extractor(detected_client: ClientType) -> BodyCodeSnippetExtr ClientType.OPEN_INTERPRETER: OpenInterpreterBodySnippetExtractor(), } return mapping_client_extractor.get(detected_client, ContinueBodySnippetExtractor()) + + +class MessageCodeExtractorFactory: + + @staticmethod + def create_snippet_extractor(detected_client: ClientType) -> CodeSnippetExtractor: + mapping_client_extractor = { + ClientType.GENERIC: DefaultCodeSnippetExtractor(), + ClientType.CLINE: ClineCodeSnippetExtractor(), + ClientType.AIDER: AiderCodeSnippetExtractor(), + ClientType.OPEN_INTERPRETER: OpenInterpreterCodeSnippetExtractor(), + } + return mapping_client_extractor.get(detected_client, DefaultCodeSnippetExtractor()) diff --git a/src/codegate/muxing/router.py b/src/codegate/muxing/router.py index b75a81bf..5771e59e 100644 --- a/src/codegate/muxing/router.py +++ b/src/codegate/muxing/router.py @@ -5,7 +5,8 @@ from codegate.clients.clients import ClientType from codegate.clients.detector import DetectClient -from codegate.extract_snippets.factory import CodeSnippetExtractorFactory +from codegate.extract_snippets.body_extractor import BodyCodeSnippetExtractorError +from codegate.extract_snippets.factory import BodyCodeExtractorFactory from codegate.muxing import rulematcher from codegate.muxing.adapter import BodyAdapter, ResponseAdapter from codegate.providers.registry import ProviderRegistry @@ -42,8 +43,12 @@ def _extract_request_filenames(self, detected_client: ClientType, data: dict) -> """ Extract filenames from the request data. """ - body_extractor = CodeSnippetExtractorFactory.create_snippet_extractor(detected_client) - return body_extractor.extract_unique_snippets(data) + try: + body_extractor = BodyCodeExtractorFactory.create_snippet_extractor(detected_client) + return body_extractor.extract_unique_filenames(data) + except BodyCodeSnippetExtractorError as e: + logger.error(f"Error extracting filenames from request: {e}") + return set() async def _get_model_routes(self, filenames: set[str]) -> list[rulematcher.ModelRoute]: """ diff --git a/src/codegate/pipeline/codegate_context_retriever/codegate.py b/src/codegate/pipeline/codegate_context_retriever/codegate.py index 15a5b122..f373e874 100644 --- a/src/codegate/pipeline/codegate_context_retriever/codegate.py +++ b/src/codegate/pipeline/codegate_context_retriever/codegate.py @@ -5,7 +5,7 @@ from litellm import ChatCompletionRequest from codegate.clients.clients import ClientType -from codegate.extract_snippets.message_extractor import DefaultCodeSnippetExtractor +from codegate.extract_snippets.factory import MessageCodeExtractorFactory from codegate.pipeline.base import ( AlertSeverity, PipelineContext, @@ -25,9 +25,6 @@ class CodegateContextRetriever(PipelineStep): the word "codegate" in the user message. """ - def __init__(self): - self.extractor = DefaultCodeSnippetExtractor() - @property def name(self) -> str: """ @@ -73,7 +70,8 @@ async def process( # noqa: C901 storage_engine = StorageEngine() # Extract any code snippets - snippets = self.extractor.extract_snippets(user_message) + extractor = MessageCodeExtractorFactory.create_snippet_extractor(context.client) + snippets = extractor.extract_snippets(user_message) bad_snippet_packages = [] if len(snippets) > 0: diff --git a/tests/extract_snippets/test_body_extractor.py b/tests/extract_snippets/test_body_extractor.py index ab4aaf4e..154efc2c 100644 --- a/tests/extract_snippets/test_body_extractor.py +++ b/tests/extract_snippets/test_body_extractor.py @@ -64,7 +64,7 @@ def _evaluate_actual_filenames(filenames: set[str], test_case: BodyCodeSnippetTe ) def test_body_extract_openinterpreter_snippets(test_case: BodyCodeSnippetTest): extractor = OpenInterpreterBodySnippetExtractor() - filenames = extractor.extract_unique_snippets(test_case.input_body_dict) + filenames = extractor.extract_unique_filenames(test_case.input_body_dict) _evaluate_actual_filenames(filenames, test_case) @@ -155,5 +155,5 @@ async def _process_prompt_output_to_partial_qa( ) def test_body_extract_cline_snippets(test_case: BodyCodeSnippetTest): extractor = ClineBodySnippetExtractor() - filenames = extractor.extract_unique_snippets(test_case.input_body_dict) + filenames = extractor.extract_unique_filenames(test_case.input_body_dict) _evaluate_actual_filenames(filenames, test_case)