In [None]:
#|default_exp client

In [None]:
#|export

from typing import Optional
from cheshire_cat_api.config import Config
from cheshire_cat_api import CatClient
from cheshire_cat_api.models import SettingBody
import requests
import time
from queue import Queue
import json



In [None]:
#|export

class SuperCatClient:
    """
    A wrapper around the official client for sane handling of the websocket connection.

    Uses a queue to communicate with the websocket thread and blocks until a response is received.
    This is needed as there is a bug in the cat that results in tools not being executed if we make a simple POST request.

    Attributes that are not found on the wrapper are looked up on the wrapped
    ``CatClient``. The wrapper can be used as a context manager; the
    connection is closed on exit.
    """

    def __init__(self, config: Optional["Config"] = None):
        """Create the wrapped client, open the websocket and wait until it is connected.

        Raises:
            TimeoutError: if the websocket is not connected within the default
                timeout of `wait_for_connection`.
        """
        # The queue must exist before the websocket is opened: `on_message`
        # runs on the websocket thread and may fire as soon as we connect.
        self.queue = Queue()
        self._closed = False
        self.cat_client = CatClient(config, on_message=self.on_message)
        self.cat_client.connect_ws()
        self.wait_for_connection()

        self.host = self.cat_client.memory.api_client.configuration.host

    def on_message(self, message):
        """Websocket callback: decode `message` as JSON and queue it for `send`.

        Runs on the websocket thread. Messages whose ``type`` is
        ``"chat_token"`` are dropped; undecodable messages are reported with
        `print` and dropped.
        """
        try:
            decoded = json.loads(message)
        except json.JSONDecodeError as e:
            print(f"Failed to decode message: {e}")
            return

        # Guard the `.get`: valid JSON is not necessarily an object, and an
        # exception here would be raised on the websocket thread.
        if isinstance(decoded, dict) and decoded.get("type") == "chat_token":
            return
        self.queue.put(decoded)

    def wait_for_connection(self, timeout=10):
        """Block until the websocket reports it is connected.

        Args:
            timeout: maximum number of seconds to wait.

        Raises:
            TimeoutError: if the connection is not up after `timeout` seconds.
        """
        # Monotonic clock: immune to wall-clock adjustments while waiting.
        deadline = time.monotonic() + timeout
        # The connection state is re-checked before the deadline, so a
        # connection that comes up during the last sleep is not reported as a
        # timeout.
        while not self.cat_client.is_ws_connected:
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                raise TimeoutError(
                    f"Failed to connect to WebSocket within timeout ({timeout} sec)."
                )
            time.sleep(min(0.1, remaining))

    def send(self, message, timeout: Optional[float] = 10):
        """Send `message` through the wrapped client and return the next queued reply.

        Args:
            message: forwarded unchanged to the wrapped client's ``send``.
            timeout: seconds to wait for a reply; ``None`` waits indefinitely.

        Returns:
            The next decoded message put on the queue by `on_message`.

        Raises:
            TimeoutError: if no reply is queued within `timeout` seconds.
        """
        from queue import Empty

        self.cat_client.send(message)
        try:
            # `timeout` must be passed by keyword: the first positional
            # parameter of `Queue.get` is `block`, not the timeout.
            return self.queue.get(timeout=timeout)
        except Empty:
            raise TimeoutError(
                f"No response received within timeout ({timeout} sec)."
            ) from None

    def udpate_setting(self, name, value, category=""):
        """Overwrite the setting called `name` through the ``/settings/`` endpoints.

        The method name keeps its historical spelling so existing callers
        keep working.

        Args:
            name: name of the setting to update; used to look up its id.
            value: new value, sent as the ``value`` field of the JSON body.
            category: sent as the ``category`` field of the JSON body.

        Returns:
            The `requests` response of the PUT request.

        Raises:
            ValueError: if no setting called `name` is listed by the server.
            requests.HTTPError: if either request returns an error status.
        """
        listing = requests.get(f"{self.host}/settings/")
        listing.raise_for_status()
        setting_id = next(
            (s["setting_id"] for s in listing.json()["settings"] if s["name"] == name),
            None,
        )
        if setting_id is None:
            # Without a default, `next` would leak a bare StopIteration here.
            raise ValueError(f"No setting named {name!r} found at {self.host}/settings/")

        r = requests.put(
            f"{self.host}/settings/{setting_id}",
            json={
                "name": name,
                "value": value,
                "category": category,
            },
        )
        r.raise_for_status()
        return r

    def udpate_llm_setting(self, llm_name, value):
        """PUT `value` as the JSON body of ``/llm/settings/{llm_name}``.

        The method name keeps its historical spelling so existing callers
        keep working.

        Returns:
            The `requests` response.

        Raises:
            requests.HTTPError: if the server returns an error status.
        """
        r = requests.put(
            f"{self.host}/llm/settings/{llm_name}",
            json=value,
        )
        r.raise_for_status()
        return r

    def delete_episodic_memory(self):
        """Send DELETE to ``/memory/conversation_history`` and return the response.

        Raises:
            requests.HTTPError: if the server returns an error status.
        """
        r = requests.delete(f"{self.host}/memory/conversation_history")
        r.raise_for_status()
        return r

    def __getattr__(self, name):
        """Forward attributes not found on the wrapper to the official client.

        Raises:
            AttributeError: if the wrapped client is not set yet, or if it
                does not have the attribute either.
        """
        # Only called when normal lookup fails. Read `cat_client` straight
        # from the instance dict: going through attribute access (or
        # `hasattr`) would re-enter this method and recurse without bound
        # when `cat_client` has not been assigned.
        try:
            wrapped = self.__dict__["cat_client"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        return getattr(wrapped, name)

    def close(self):
        """Close the wrapped client; calls after the first successful one are no-ops."""
        if self.__dict__.get("_closed", False):
            return
        self.cat_client.close()
        self._closed = True

    def __del__(self):
        # `__init__` may have failed before `cat_client` was assigned; there
        # is nothing to close in that case.
        if "cat_client" in self.__dict__:
            self.close()

    def __enter__(self):
        """Enter the runtime context related to this object."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Exit the runtime context and clean up resources."""
        self.close()

In [None]:
# Notebook-level client with the default configuration; the constructor opens the
# websocket and blocks until it is connected (or raises TimeoutError).
client = SuperCatClient()

In [None]:
# Rewrite the `llm_selected` setting so it names the OpenAI chat configuration.
client.udpate_setting(
    name="llm_selected",
    value={"name": "LLMOpenAIChatConfig"},
)

<Response [200]>

In [None]:
# Push a new payload for the OpenAI chat configuration.
# NOTE(review): the payload uses the key `model`, while `LLMOpenAIChatConfig` in this
# notebook names the field `model_name` — confirm which key the server expects.
client.udpate_llm_setting(
    llm_name="LLMOpenAIChatConfig",
    value={"model": "gpt-4"},
)

<Response [200]>

In [None]:
#|export
from pydantic import BaseModel, SecretStr
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import ClassVar
from pyprojroot import here

In [None]:
#|export
class LLMSetting(BaseModel):
    """Common base class for the per-provider LLM configuration models; adds no fields."""

In [None]:
#|export
class LLMOpenAIChatConfig(LLMSetting):
    """Settings payload for the `LLMOpenAIChatConfig` LLM configuration.

    `openai_api_key` is required; every other field has a default.
    """
    # Class-level constant, not a model field, so it does not appear in `model_dump()`.
    name: ClassVar[str] = "LLMOpenAIChatConfig"
    openai_api_key: str  # required, no default
    # NOTE(review): other cells in this notebook send the key `model` for this
    # configuration, while the field here is `model_name` — confirm which one
    # the server expects.
    model_name: str = "gpt-5-mini"
    temperature: float = 1.0
    streaming: bool = False

In [None]:
# Demo: copy a config with one field overridden and show the dict that would be sent.
(
    LLMOpenAIChatConfig(openai_api_key="aa")
    .model_copy(update={"temperature": 0.5})
    .model_dump()
)

{'openai_api_key': 'aa',
 'model_name': 'gpt-5-mini',
 'temperature': 0.5,
 'streaming': False}

In [None]:
#|export
class LLMOllamaConfig(LLMSetting):
    """Settings payload for the `LLMOllamaConfig` LLM configuration.

    `base_url` is required; every other field has a default.
    """
    # Class-level constant, not a model field; used elsewhere in this notebook as
    # the configuration name when the payload is pushed to the server.
    name: ClassVar[str] = "LLMOllamaConfig"
    base_url: str  # required, no default
    model: str = "llama3"
    num_ctx: int = 2048
    repeat_last_n: int = 64
    repeat_penalty: float = 1.1
    temperature: float = 1.0

In [None]:
#|export
class LLMGeminiChatConfig(LLMSetting):
    """Settings payload for the `LLMGeminiChatConfig` LLM configuration.

    `google_api_key` is required; every other field has a default.
    """
    # Class-level constant, not a model field.
    name: ClassVar[str] = "LLMGeminiChatConfig"
    google_api_key: str  # required, no default
    model: str = "gemini-2.5-pro-latest"
    temperature: float = 1.0
    # NOTE(review): typed as int, so fractional values such as 0.9 would presumably
    # fail validation — confirm this matches the server-side schema before changing.
    top_p: int = 1
    top_k: int = 1
    max_output_tokens: int = 29000

In [None]:
#|export
class LLMSettings(BaseSettings):
    """One configuration per LLM provider, loaded through pydantic-settings.

    All three provider fields are declared without defaults.
    """
    # "__" is the delimiter between nesting levels in variable names
    # (presumably e.g. `OLLAMA__BASE_URL` — confirm against the .env file).
    # The env file path is resolved with pyprojroot's `here(".env.local")`.
    model_config = SettingsConfigDict(env_nested_delimiter="__", env_file=here(".env.local"))
    openai: LLMOpenAIChatConfig
    ollama: LLMOllamaConfig
    gemini: LLMGeminiChatConfig

In [None]:
# Load the per-provider LLM configurations according to LLMSettings.model_config.
settings = LLMSettings()

In [None]:
def get_llm_settings():
    """Fetch `/llm/settings/` from the cat and return the decoded JSON body.

    Uses the notebook-level `client` for the host. Raises
    `requests.HTTPError` when the server answers with an error status.
    """
    response = requests.get(f"{client.host}/llm/settings/")
    response.raise_for_status()
    payload = response.json()
    return payload

In [None]:
# Push a new payload for the OpenAI chat configuration.
client.udpate_llm_setting(
    llm_name="LLMOpenAIChatConfig",
    value={"model": "gpt-5-mini"},
)

<Response [200]>

In [None]:
# Show which LLM configuration the server currently reports as selected.
get_llm_settings()['selected_configuration']

'LLMOpenAIChatConfig'

In [None]:
# Push the Ollama configuration loaded from the environment, overriding only `model`.
client.udpate_llm_setting(
    llm_name=settings.ollama.name,
    value=settings.ollama.model_copy(update={"model": "gemma3:27b"}).model_dump(),
)

<Response [200]>

In [None]:
# Show which LLM configuration the server currently reports as selected.
get_llm_settings()['selected_configuration']

'LLMOllamaConfig'

In [None]:
def update_llm(llm_setting: LLMSetting):
    """Push an LLM configuration model to the cat through the notebook-level `client`.

    The configuration name comes from the model's class-level `name` and the
    payload from `model_dump()`.

    Returns:
        The `requests` response of the underlying PUT request.
    """
    return client.udpate_llm_setting(llm_setting.name, llm_setting.model_dump())


# Push the Ollama configuration, overriding only `model`.
update_llm(settings.ollama.model_copy(update={"model": "gemma3:27b"}))

In [None]:
# Snapshot of the server's `/settings/` listing, inspected in the following cells.
settings0 = requests.get(f"{client.host}/settings/").json()

In [None]:
# Entries of the snapshot that hold the Ollama configuration.
[
    entry
    for entry in settings0["settings"]
    if entry["name"] == settings.ollama.name
]

[{'name': 'LLMOllamaConfig',
  'value': {'base_url': 'http://192.168.100.134:11444',
   'model': 'gemma3:27b',
   'num_ctx': 2048,
   'repeat_last_n': 64,
   'repeat_penalty': 1.1,
   'temperature': 1.0},
  'category': 'llm_factory',
  'setting_id': '588a97f1-6688-4d7f-b42a-9370844a7813',
  'updated_at': 1756724055}]

In [None]:
# Entries of the snapshot that record which LLM configuration is selected.
[
    entry
    for entry in settings0["settings"]
    if entry["name"] == "llm_selected"
]

[{'name': 'llm_selected',
  'value': {'name': 'LLMOllamaConfig'},
  'category': 'llm',
  'setting_id': 'f8074e79-8a31-4b56-ad4b-5c914f844763',
  'updated_at': 1756722764}]

In [None]:
# Second snapshot of the server's `/settings/` listing, for comparison with `settings0`.
settings1 = requests.get(f"{client.host}/settings/").json()

In [None]:
# Entries of the second snapshot that record which LLM configuration is selected.
[
    entry
    for entry in settings1["settings"]
    if entry["name"] == "llm_selected"
]

[{'name': 'llm_selected',
  'value': {'name': 'LLMOpenAIChatConfig'},
  'category': 'llm',
  'setting_id': 'c677b4e2-71b0-409b-8c2c-c331c3ee7ffa',
  'updated_at': 1756723505}]

variants