# App logic

> Contains `BaseChatApp` and the large language model integration logic. Implements the core functionality independently of the UI.

In [None]:
#| default_exp app

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from typing import Protocol, runtime_checkable, Generator, List
from openai import OpenAI
from ollama import Client as OllamaSDK
from ollama import AsyncClient as AsyncOllamaSDK

from gradiochat.config import ModelConfig, Message, ChatAppConfig

  from .autonotebook import tqdm as notebook_tqdm


#### Import statement

```python
from gradiochat.app import *
```

## Explanation of new code

I used `Protocol` and `runtime_checkable` here for the first time, so here's a short explainer.

#### Python's Protocol System

Protocols were introduced in Python 3.8 through PEP 544 and are part of the `typing` module. They provide a way to define interfaces that classes can implement without explicitly inheriting from them - this is called "structural typing" or "duck typing."

#### Protocols vs Abstract Base Classes (ABCs)

**Abstract Base Classes (Traditional Approach):**
- Require explicit inheritance (`class MyClass(AbstractBaseClass):`)
- Use the `@abstractmethod` decorator to mark methods that must be implemented
- Check for compatibility based on the class hierarchy (nominal typing)
- Enforce implementation at instantiation time (instantiating a subclass that leaves an abstract method unimplemented raises `TypeError`)

**Protocols (New Approach):**
- Don't require inheritance - classes just need to implement the required methods
- Use the `Protocol` class and `@runtime_checkable` decorator
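- Static type checkers judge compatibility based on method signatures (structural typing)
- Can check compatibility at runtime with `isinstance()` if marked as `@runtime_checkable` (the runtime check only verifies that the required methods exist, not their signatures)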

#### How It Works

In our code:

```python
@runtime_checkable
class LLMClientProtocol(Protocol):
    def chat_completion(self, messages: List[Message], **kwargs) -> str:
        ...
    
    def chat_completion_stream(self, messages: List[Message], **kwargs) -> Generator[str, None, None]:
        ...
```

This defines an interface that says "any class with methods named `chat_completion` and `chat_completion_stream` with these signatures is considered compatible with `LLMClientProtocol`."

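The `...` in the method bodies is the `Ellipsis` literal. It is used here as a conventional placeholder meaning "no implementation here; only the signature matters." It is an ordinary expression that works anywhere `pass` does, but it is the idiomatic choice for protocol and stub definitions.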

#### Benefits in Our Context

1. **Flexibility**: We can create any class that implements these methods, and it will be compatible with `LLMClientProtocol` without explicitly inheriting from it.

2. **Easy Testing**: We can create mock implementations that automatically satisfy the protocol by just implementing the required methods.

3. **Type Checking**: Tools like mypy can verify that our classes implement all required methods with the correct signatures.

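4. **Runtime Checking**: With `@runtime_checkable`, we can use `isinstance(obj, LLMClientProtocol)` to check whether an object has the protocol's methods. Note that this runtime check only looks for the method names; it does not verify their signatures.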

#### Example of Use

```python
def process_with_any_llm_client(client: LLMClientProtocol, messages: List[Message]):
    # This function will accept any object that has the required methods,
    # regardless of its class hierarchy
    response = client.chat_completion(messages)
    return response
```

This would accept our `HuggingFaceClient` or any other class that implements the required methods, without forcing them to inherit from a common base class.

## Define the general LLMClientProtocol structure

Every LLM client must implement the methods defined in `LLMClientProtocol`.

In [None]:
#| export
@runtime_checkable
class LLMClientProtocol(Protocol):
    """Structural interface that every LLM client in this module satisfies.

    Any object exposing ``chat_completion`` and ``chat_completion_stream``
    with these signatures is accepted; no inheritance is required. Because the
    class is ``@runtime_checkable``, ``isinstance(obj, LLMClientProtocol)``
    also works, but at runtime it only checks that both methods exist, not
    their signatures.
    """

    def chat_completion(
        self,
        messages: List[Message],
        **kwargs,
    ) -> str:
        """Return the model's complete reply to ``messages`` as one string."""
        ...

    def chat_completion_stream(
        self,
        messages: List[Message],
        **kwargs,
    ) -> Generator[str, None, None]:
        """Yield the model's reply to ``messages`` as successive text fragments."""
        ...

## Define the LLM Clients

This should at least follow the structure of `LLMClientProtocol` but can of course be expanded.

#### HuggingFaceClient

In [None]:
#| export
class HuggingFaceClient():
    """Client for interacting with HuggingFace models.

    Talks to HuggingFace's OpenAI-compatible inference router through the
    ``openai`` package, so requests and responses use the OpenAI chat
    completions format.
    """

    def __init__(self, model_config: ModelConfig):
        """Initialize the client with model configuration.

        Falls back to the public HF Inference router when ``api_base_url`` is
        unset. Falls back to a placeholder key when ``api_key`` is unset, so the
        OpenAI client can still be constructed; the server will then reject
        the request.
        """
        self.model_config = model_config

        # Default to HF Inference API if no base URL is provided
        base_url = model_config.api_base_url or "https://router.huggingface.co/hf-inference/v1"

        self.client = OpenAI(
            base_url=base_url,
            api_key=model_config.api_key or "hf_no_api_key_provided"
        )

    def _request_kwargs(self, messages: List[Message], overrides: dict) -> dict:
        """Build the arguments shared by the blocking and streaming calls.

        Values in ``overrides`` (the caller's ``**kwargs``) take precedence over
        the defaults stored in ``self.model_config``.
        """
        return {
            "model": self.model_config.model_name,
            # Convert our Message objects to the format expected by the OpenAI client
            "messages": [{"role": msg.role, "content": msg.content} for msg in messages],
            "max_completion_tokens": overrides.get("max_completion_tokens", self.model_config.max_completion_tokens),
            "temperature": overrides.get("temperature", self.model_config.temperature),
        }

    def chat_completion(self,
            messages: List[Message], # List of messages conforming to the Message pydantic dataclass
            **kwargs
            ) -> str:
        """Generate a chat completion from the HuggingFace model.

        Recognised ``kwargs``: ``max_completion_tokens`` and ``temperature``;
        other keys are ignored. Returns ``""`` if the reply has no text content.
        """
        completion = self.client.chat.completions.create(**self._request_kwargs(messages, kwargs))

        # `content` may be None; keep the declared `str` return type.
        return completion.choices[0].message.content or ""

    def chat_completion_stream(self, messages: List[Message], **kwargs) -> Generator[str, None, None]:
        """Generate a streaming chat completion.

        Yields text fragments as the server produces them. Chunks with no
        choices or no text content are skipped. Accepts the same ``kwargs`` as
        ``chat_completion``.
        """
        stream = self.client.chat.completions.create(
            **self._request_kwargs(messages, kwargs),
            stream=True,
        )

        for chunk in stream:
            # Some chunks (e.g. a trailing usage chunk) carry an empty `choices` list.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content:
                yield content

#### TogetherAI

In [None]:
#| export
class TogetherAiClient():
    """Client for interacting with models through the TogetherAI API server.

    Uses the ``openai`` package against Together AI's OpenAI-compatible API.
    """

    # Llama-3-style end-of-turn / end-of-message tokens, used whenever no
    # (or an empty) stop list is configured or passed in.
    _DEFAULT_STOP = ("<|eot_id|>", "<|eom_id|>")

    def __init__(self, model_config: ModelConfig):
        """Initialize the client with model configuration"""
        self.model_config = model_config

        self.client = OpenAI(
            base_url=model_config.api_base_url or "https://api.together.xyz/v1", # Default to Together AI Inference API if no base URL is provided
            api_key=model_config.api_key,
        )

    def _request_kwargs(self, messages: List[Message], overrides: dict) -> dict:
        """Build the arguments shared by the blocking and streaming calls.

        Values in ``overrides`` (the caller's ``**kwargs``) take precedence over
        the defaults stored in ``self.model_config``.
        """
        return {
            "model": self.model_config.model_name,
            # Convert our Message objects to the format expected by the OpenAI client
            "messages": [{"role": msg.role, "content": msg.content} for msg in messages],
            "max_completion_tokens": overrides.get("max_completion_tokens", self.model_config.max_completion_tokens),
            "temperature": overrides.get("temperature", self.model_config.temperature),
            "top_p": overrides.get("top_p", self.model_config.top_p),
            # A falsy stop value (None or empty) falls back to the default stop tokens.
            "stop": overrides.get("stop", self.model_config.stop) or list(self._DEFAULT_STOP),
        }

    def chat_completion(self,
            messages: List[Message], # List of messages conforming to the Message pydantic dataclass
            **kwargs
            ) -> str:
        """Generate a chat completion from the Together AI API.

        Recognised ``kwargs``: ``max_completion_tokens``, ``temperature``,
        ``top_p`` and ``stop``. Returns ``""`` if the reply has no text content.
        """
        completion = self.client.chat.completions.create(**self._request_kwargs(messages, kwargs))

        # `content` may be None; keep the declared `str` return type.
        return completion.choices[0].message.content or ""

    def chat_completion_stream(self,
            messages: List[Message], # List of messages conforming to the Message pydantic dataclass
            **kwargs) -> Generator[str, None, None]:
        """Generate a streaming chat completion.

        Yields text fragments as they arrive. Chunks with no choices or no
        text content are skipped. Accepts the same ``kwargs`` as
        ``chat_completion``.
        """
        stream = self.client.chat.completions.create(
            **self._request_kwargs(messages, kwargs),
            stream=True,
        )

        for chunk in stream:
            # `choices` can be an empty list (e.g. a final usage-only chunk);
            # indexing [0] unconditionally would raise IndexError.
            if not chunk.choices:
                continue
            content = chunk.choices[0].delta.content
            if content is not None:
                yield content
           

#### Local Ollama client

In [None]:
#| export
class OllamaClient():
    """Client for interacting with models through a local Ollama API server.

    Uses the official Ollama Python library. No API key is used.
    """

    def __init__(self, model_config: ModelConfig):
        """Initialize the client with model configuration.

        Connects to ``model_config.api_base_url`` or, if unset, to
        ``http://localhost:11434``.
        """
        self.model_config = model_config

        # Extract host from api_base_url or use default
        host = model_config.api_base_url or "http://localhost:11434"

        # Create Ollama client
        self.client = OllamaSDK(host=host)

    def _build_params(self, messages: List[Message], overrides: dict, stream: bool = False) -> dict:
        """Build the keyword arguments for ``self.client.chat``.

        Values in ``overrides`` (the caller's ``**kwargs``) take precedence over
        ``self.model_config``. Options whose resolved value is None are left
        out, so Ollama falls back to the model's own defaults instead of
        receiving nulls.
        """
        options = {
            "temperature": overrides.get("temperature", self.model_config.temperature),
            "top_p": overrides.get("top_p", self.model_config.top_p),
            # Ollama's name for the maximum number of generated tokens.
            "num_predict": overrides.get("max_completion_tokens", self.model_config.max_completion_tokens),
            "stop": overrides.get("stop", self.model_config.stop),
        }

        params = {
            "model": self.model_config.model_name,
            # Convert our Message objects to the format expected by the Ollama client
            "messages": [{"role": msg.role, "content": msg.content} for msg in messages],
            "options": {name: value for name, value in options.items() if value is not None},
        }
        if stream:
            params["stream"] = True
        return params

    def chat_completion(self,
            messages: List[Message], # List of messages conforming to the Message pydantic dataclass
            **kwargs
            ) -> str:
        """Generate a chat completion from the Ollama API.

        Recognised ``kwargs``: ``temperature``, ``top_p``,
        ``max_completion_tokens`` (sent as ``num_predict``) and ``stop``.
        """
        response = self.client.chat(**self._build_params(messages, kwargs))

        # Extract the generated text
        return response.message.content

    def chat_completion_stream(self,
            messages: List[Message], # List of messages conforming to the Message pydantic dataclass
            **kwargs) -> Generator[str, None, None]:
        """Generate a streaming chat completion.

        Yields each non-empty content fragment from the Ollama stream. Accepts
        the same ``kwargs`` as ``chat_completion``.
        """
        stream = self.client.chat(**self._build_params(messages, kwargs, stream=True))

        # Yield each chunk of content
        for chunk in stream:
            if chunk.message and chunk.message.content:
                yield chunk.message.content



## Create the LLM client

This function creates the client using the available LLM Client classes.
It reads the provider from `model_config` and matches it case-insensitively. If an LLM client class exists for that provider, it returns an instance of that class; otherwise it raises a `ValueError`.

In [None]:
#| export
def create_llm_client(model_config: ModelConfig) -> LLMClientProtocol:
    """
    Build the LLM client that matches ``model_config.provider``.

    The provider is matched case-insensitively against "huggingface",
    "togetherai" and "ollama", and the matching client class is instantiated
    with ``model_config``.

    Raises:
        ValueError: if there is no client class for the provider.
    """
    client_classes = {
        "huggingface": HuggingFaceClient,
        "togetherai": TogetherAiClient,
        "ollama": OllamaClient,
    }

    client_cls = client_classes.get(model_config.provider.lower())
    if client_cls is None:
        raise ValueError(f"Unsupported provider: {model_config.provider}")
    return client_cls(model_config)

## The internal logic of the chat app

Next, the `BaseChatApp` class is defined. It holds the properties and methods for the internal workings of the chat app; the UI is defined in the `ui` module.

In [None]:
#| export
class BaseChatApp:
    """Base class for creating configurable chat applications with Gradio.

    Holds the app configuration, the chat history, the loaded context text and
    the LLM client. It turns a user message into the message list sent to the
    model. The UI itself is defined in the ``ui`` module.
    """

    def __init__(self, config: ChatAppConfig):
        """Store ``config``, load the context files and create the LLM client."""
        self.config = config
        # Entries are read as dicts with 'role' and 'content' keys (see prepare_messages).
        self.chat_history = []
        self._load_context()
        self.client = create_llm_client(config.model)

    def _load_context(self) -> None:
        """Concatenate the configured context files into ``self.context_text``.

        Each entry of ``config.context_files`` that exists and is a regular file
        is read as UTF-8 and followed by a blank line. Other entries are
        skipped silently.
        """
        pieces = []
        for path in self.config.context_files:
            if not (path.exists() and path.is_file()):
                continue
            with open(path, 'r', encoding='utf-8') as handle:
                pieces.append(handle.read() + "\n\n")
        self.context_text = "".join(pieces)

    def prepare_messages(self, user_message: str) -> List[Message]:
        """Assemble the prompt: system message, prior turns, then ``user_message``.

        The system message is ``config.system_prompt``. When any context was
        loaded, the context is appended after an "Additional information:" label.
        """
        system_content = self.config.system_prompt
        if self.context_text:
            system_content += f"\n\nAdditional information: {self.context_text}"

        prompt = [Message(role="system", content=system_content)]
        prompt.extend(
            Message(role=turn['role'], content=turn['content'])
            for turn in self.chat_history
        )
        prompt.append(Message(role="user", content=user_message))
        return prompt

    def generate_response(self, user_message: str, **kwargs) -> str:
        """Return the model's complete reply to ``user_message``.

        ``kwargs`` are forwarded unchanged to the client's ``chat_completion``.
        """
        return self.client.chat_completion(self.prepare_messages(user_message), **kwargs)

    def generate_stream(self, user_message: str, **kwargs) -> Generator[str, None, None]:
        """Return a generator of reply fragments for ``user_message``.

        The prompt is built immediately, before the generator is returned.
        ``kwargs`` are forwarded to ``chat_completion_stream``. Neither this
        method nor ``generate_response`` appends to ``chat_history``.
        """
        return self.client.chat_completion_stream(self.prepare_messages(user_message), **kwargs)

NameError: name 'ChatAppConfig' is not defined

Create a HuggingFace test model config

In [None]:
#| eval: false
# Skipped by nbdev_test (eval: false): the API key is read from .env, which is
# not available when nbdev_test runs.
hf_config = ModelConfig(
    provider="huggingface",
    # "Qwen/QwQ-32B" is another option, but vision models need a different messages format.
    model_name="mistralai/Mistral-7B-Instruct-v0.3",
    api_base_url="https://router.huggingface.co/hf-inference/v1",
    api_key_env_var="HF_API_KEY",
    temperature=0.7,
    max_completion_tokens=100,
)

# Build the matching client (a HuggingFaceClient for this provider).
client = create_llm_client(hf_config)

Create a Together AI test model config

In [None]:
#| eval: false
# Skipped by nbdev_test (eval: false): the API key is read from .env, which is
# not available when nbdev_test runs.
ta_config = ModelConfig(
    provider="togetherai",
    # Alternative model to try: "mistralai/Mistral-Nemo-Instruct-2407".
    model_name="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
    api_key_env_var="TG_API_KEY",
)

# Build the matching client (a TogetherAiClient for this provider).
client = create_llm_client(ta_config)

NameError: name 'TogetherAiClient' is not defined

Create an Ollama test model config

In [None]:
#| eval: false
# Skipped by nbdev_test (eval: false): it needs a running Ollama server
# (by default at http://localhost:11434), which is not available during nbdev_test.
olla_config = ModelConfig(
    provider="ollama",
    model_name="nchapman/ministral-8b-instruct-2410",
    # OllamaClient never reads an API key; presumably this is set only to
    # satisfy ModelConfig — confirm.
    api_key_env_var="OLLAMA_API_KEY",
)

# Build the matching client (an OllamaClient for this provider).
client = create_llm_client(olla_config)

NameError: name 'OllamaClient' is not defined

In [None]:
# Smoke test for whichever `client` was created most recently above.
test_messages = [
    Message(role="system", content="You are Aurelius Augustinus, helping me to think deeply and be humble and thankful."),
    Message(role="user", content="Why should I engage with the people around me?")
]
# Test with a simple prompt
try:
    response = client.chat_completion(test_messages)
    # `or ''` keeps a None reply from surfacing as a misleading TypeError below.
    print(f"Response received: {(response or '')[:100]}...")
except Exception as e:
    # Broad catch is deliberate: this is an interactive smoke test, and provider
    # errors (auth, quota, connectivity) should be printed, not abort the notebook.
    print(f"Error: {e}")

# Test with overridden parameters
try:
    print("\nTesting with overridden parameters:")
    response = client.chat_completion(test_messages, max_completion_tokens=50, temperature=0.9)
    print(f"Response: {(response or '')[:100]}...")  # Show first 100 chars
except Exception as e:
    print(f"Error: {e}")

Error: Error code: 402 - {'error': 'You have exceeded your monthly included credits for Inference Providers. Subscribe to PRO to get 20x more monthly included credits.'}

Testing with overridden parameters:
Error: Error code: 402 - {'error': 'You have exceeded your monthly included credits for Inference Providers. Subscribe to PRO to get 20x more monthly included credits.'}


In [None]:
#| hide
# Regenerate the Python modules from the notebooks.
import nbdev
nbdev.nbdev_export()