diff --git a/llm_command_r.py b/llm_command_r.py index c935729..28bf4f4 100644 --- a/llm_command_r.py +++ b/llm_command_r.py @@ -84,6 +84,10 @@ class Options(llm.Options): description="Use web search connector", default=False, ) + base_url: Optional[str] = Field( + description="API base URL (if not the default Cohere endpoint)", + default=None, + ) def __init__(self, model_id): self.model_id = model_id @@ -101,7 +105,7 @@ def build_chat_history(self, conversation) -> List[dict]: return chat_history def execute(self, prompt, stream, response, conversation): - client = cohere.Client(self.get_key()) + client = cohere.Client(self.get_key(), base_url=prompt.options.base_url) kwargs = { "message": prompt.prompt, "model": self.model_id,