6 changes: 6 additions & 0 deletions pr_agent/algo/__init__.py
@@ -45,6 +45,7 @@
    'command-nightly': 4096,
    'deepseek/deepseek-chat': 128000, # 128K, but may be limited by config.max_model_tokens
    'deepseek/deepseek-reasoner': 64000, # 64K, but may be limited by config.max_model_tokens
    'openai/qwq-plus': 131072, # 131K context length, but may be limited by config.max_model_tokens
    'replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1': 4096,
    'meta-llama/Llama-2-7b-chat-hf': 4096,
    'vertex_ai/codechat-bison': 6144,
@@ -193,3 +194,8 @@
"anthropic/claude-3-7-sonnet-20250219",
"claude-3-7-sonnet-20250219"
]

# Models that require streaming mode
STREAMING_REQUIRED_MODELS = [
"openai/qwq-plus"
]
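
As a minimal sketch of how the new constant is consumed (the handler changes below perform the equivalent check inside _get_completion), streaming is gated on a simple membership test. The needs_streaming helper here is illustrative only and is not part of the PR:

    from pr_agent.algo import STREAMING_REQUIRED_MODELS

    def needs_streaming(model: str) -> bool:
        # Streaming is forced purely by membership in STREAMING_REQUIRED_MODELS.
        return model in STREAMING_REQUIRED_MODELS

    assert needs_streaming("openai/qwq-plus")
    assert not needs_streaming("deepseek/deepseek-chat")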
101 changes: 87 additions & 14 deletions pr_agent/algo/ai_handlers/litellm_ai_handler.py
@@ -5,7 +5,7 @@
from litellm import acompletion
from tenacity import retry, retry_if_exception_type, retry_if_not_exception_type, stop_after_attempt

from pr_agent.algo import CLAUDE_EXTENDED_THINKING_MODELS, NO_SUPPORT_TEMPERATURE_MODELS, SUPPORT_REASONING_EFFORT_MODELS, USER_MESSAGE_ONLY_MODELS
from pr_agent.algo import CLAUDE_EXTENDED_THINKING_MODELS, NO_SUPPORT_TEMPERATURE_MODELS, SUPPORT_REASONING_EFFORT_MODELS, USER_MESSAGE_ONLY_MODELS, STREAMING_REQUIRED_MODELS
from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.utils import ReasoningEffort, get_version
from pr_agent.config_loader import get_settings
@@ -15,6 +15,23 @@
OPENAI_RETRIES = 5


class MockResponse:
    """Mock response object for streaming models to enable consistent logging."""

    def __init__(self, resp, finish_reason):
        self._data = {
            "choices": [
                {
                    "message": {"content": resp},
                    "finish_reason": finish_reason
                }
            ]
        }

    def dict(self):
        return self._data


class LiteLLMAIHandler(BaseAiHandler):
"""
This class handles interactions with the OpenAI API for chat completions.
@@ -143,6 +160,9 @@ def __init__(self):
        # Models that support extended thinking
        self.claude_extended_thinking_models = CLAUDE_EXTENDED_THINKING_MODELS

        # Models that require streaming
        self.streaming_required_models = STREAMING_REQUIRED_MODELS

    def _get_azure_ad_token(self):
        """
        Generates an access token using Azure AD credentials from settings.
@@ -370,7 +390,9 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
get_logger().info(f"\nSystem prompt:\n{system}")
get_logger().info(f"\nUser prompt:\n{user}")

response = await acompletion(**kwargs)
# Get completion with automatic streaming detection
resp, finish_reason, response_obj = await self._get_completion(model, **kwargs)

except openai.RateLimitError as e:
get_logger().error(f"Rate limit error during LLM inference: {e}")
raise
@@ -380,19 +402,70 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
        except Exception as e:
            get_logger().warning(f"Unknown error during LLM inference: {e}")
            raise openai.APIError from e
        if response is None or len(response["choices"]) == 0:
            raise openai.APIError
        else:
            resp = response["choices"][0]['message']['content']
            finish_reason = response["choices"][0]["finish_reason"]
            get_logger().debug(f"\nAI response:\n{resp}")

            # log the full response for debugging
            response_log = self.prepare_logs(response, system, user, resp, finish_reason)
            get_logger().debug("Full_response", artifact=response_log)
        get_logger().debug(f"\nAI response:\n{resp}")

            # for CLI debugging
            if get_settings().config.verbosity_level >= 2:
                get_logger().info(f"\nAI response:\n{resp}")
        # log the full response for debugging
        response_log = self.prepare_logs(response_obj, system, user, resp, finish_reason)
        get_logger().debug("Full_response", artifact=response_log)

        # for CLI debugging
        if get_settings().config.verbosity_level >= 2:
            get_logger().info(f"\nAI response:\n{resp}")

        return resp, finish_reason

    async def _handle_streaming_response(self, response):
        """
        Handle streaming response from acompletion and collect the full response.

        Args:
            response: The streaming response object from acompletion

        Returns:
            tuple: (full_response_content, finish_reason)
        """
        full_response = ""
        finish_reason = None

        try:
            async for chunk in response:
                if chunk.choices and len(chunk.choices) > 0:
                    choice = chunk.choices[0]
                    delta = choice.delta
                    content = getattr(delta, 'content', None)
                    if content:
                        full_response += content
                    if choice.finish_reason:
                        finish_reason = choice.finish_reason
        except Exception as e:
            get_logger().error(f"Error handling streaming response: {e}")
            raise

        if not full_response and finish_reason is None:
            get_logger().warning("Streaming response resulted in empty content with no finish reason")
            raise openai.APIError("Empty streaming response received without proper completion")
        elif not full_response and finish_reason:
            get_logger().debug(f"Streaming response resulted in empty content but completed with finish_reason: {finish_reason}")
            raise openai.APIError(f"Streaming response completed with finish_reason '{finish_reason}' but no content received")
        return full_response, finish_reason

    async def _get_completion(self, model, **kwargs):
        """
        Wrapper that automatically handles streaming for required models.
        """
        if model in self.streaming_required_models:
            kwargs["stream"] = True
            get_logger().info(f"Using streaming mode for model {model}")
            response = await acompletion(**kwargs)
            resp, finish_reason = await self._handle_streaming_response(response)
            # Create MockResponse for streaming since we don't have the full response object
            mock_response = MockResponse(resp, finish_reason)
            return resp, finish_reason, mock_response
        else:
            response = await acompletion(**kwargs)
            if response is None or len(response["choices"]) == 0:
                raise openai.APIError
            return (response["choices"][0]['message']['content'],
                    response["choices"][0]["finish_reason"],
                    response)
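
For reference, the chunk-accumulation pattern used by _handle_streaming_response can be exercised in isolation. The sketch below drives the same logic with a fabricated chunk stream built from SimpleNamespace objects, which stand in for the real litellm streaming objects, and omits the handler's logging and error handling:

    import asyncio
    from types import SimpleNamespace

    async def fake_stream():
        # Fabricated chunks exposing only the attributes the handler reads:
        # chunk.choices[0].delta.content and chunk.choices[0].finish_reason.
        for content, finish in [("Hello", None), (" world", None), (None, "stop")]:
            yield SimpleNamespace(
                choices=[SimpleNamespace(delta=SimpleNamespace(content=content),
                                         finish_reason=finish)])

    async def collect(stream):
        # Same accumulation idea as _handle_streaming_response, minus the
        # logging and error handling.
        full_response, finish_reason = "", None
        async for chunk in stream:
            if chunk.choices:
                choice = chunk.choices[0]
                if getattr(choice.delta, "content", None):
                    full_response += choice.delta.content
                if choice.finish_reason:
                    finish_reason = choice.finish_reason
        return full_response, finish_reason

    print(asyncio.run(collect(fake_stream())))  # ('Hello world', 'stop')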