In [1]:
from utils import query_raven, build_raven_prompt
import requests

In [2]:
def get_age(age: int = None) -> int:
    """
    Return the client's age unchanged after rejecting negative values.

    Args:
        age (int, optional): The age of the client. Must be a non-negative integer.
            Defaults to None, meaning no age was supplied.

    Returns:
        int: The same `age` that was passed in, or None when no age was given.

    Raises:
        ValueError: If `age` is provided and is negative.
    """
    # A missing age is allowed and is handed straight back to the caller.
    if age is None:
        return None

    if age < 0:
        raise ValueError("age must be a non-negative integer.")

    return age

In [3]:
def get_sector_name(sector_name: str = None) -> str:
    """
    Check the client's sector name against the supported sectors and return it.

    Args:
        sector_name (str): The sector in which the client works. Must match exactly one of the following sectors:
                           ["financial", "industrial", "technology"].

    Returns:
        str: `sector_name`, unchanged, when it is one of the supported sectors.

    Raises:
        ValueError: If `sector_name` is not one of the supported sectors
            (this includes the default of None).
    """
    # Kept as a list so the error message shows the options in list form.
    allowed = ["financial", "industrial", "technology"]

    if sector_name in allowed:
        return sector_name

    raise ValueError(f"sector_name must be one of the following: {allowed}")

In [4]:
def handle_unspecified_arguments(missing_args: list[str]) -> str:
    """
    Handles the case where client details are missing.

    Args:
        missing_args (list[str]): The names of the missing client details.

    Returns:
        str: A message prompting the user to provide the missing arguments,
            with the names joined by ", ".

    Example:
        >>> handle_unspecified_arguments(['age', 'sector_name'])
        'Please provide the following missing arguments: age, sector_name'
    """
    return f"Please provide the following missing arguments: {', '.join(missing_args)}"

In [5]:
def fallback(prompt: str) -> str:
    """
    Processes a general or fallback prompt that does not correspond to a specific function call.

    The user's text is wrapped in [INST] tags, followed by a second instruction turn that
    introduces the assistant as Raven, and the combined text is passed to `query_raven`.

    Typical prompts that belong here are small talk or open questions, for example
    "Hi, how are you?" or "Can I ask a question?".

    Args:
        prompt (str): The user's input prompt which does not require a specific function call.

    Returns:
        str: Whatever `query_raven` returns for the wrapped prompt.
    """
    # The template's line breaks and indentation are part of the text sent to the model.
    return query_raven(
        f"""
    <s> [INST] {prompt} [/INST]
    <s> [INST] I am called Raven. How can i assist you today?[/INST]
    """
    )

In [6]:
def unknown_arguments(unknown_args: list[str]) -> str:
    """
    Handles arguments the user says they do not know by giving a fixed reply.

    The function returns the same message regardless of which argument names are
    passed; `unknown_args` itself is not inspected.

    Args:
        unknown_args (list[str]): A list of argument names that are unknown by the user.

    Returns:
        str: A message stating that default values were used for the unknown arguments.

    Example:
        >>> unknown_arguments(["age", "sector_name"])
        'I used default values for them'
    """
    reply = "I used default values for them"
    return reply

In [7]:
def predict_churn(age: int = None, sector_name: str = None) -> dict:
    """
    Calls an API to predict whether or not the client will churn.

    The HTTP request is currently commented out, so the function returns a
    placeholder instead of a real prediction.

    Args:
        age (int): The age of the client. Must be a non-negative integer.
        sector_name (str): The sector in which the client works.

    Returns:
        dict: Intended to be the prediction from the API. As written, the function
            returns the string "0" when both arguments are given, or the message
            from `handle_unspecified_arguments` when any argument is None.
    """
    supplied = {"age": age, "sector_name": sector_name}

    # Only None counts as missing; falsy values such as 0 are accepted.
    absent = [name for name in supplied if supplied[name] is None]
    if absent:
        return handle_unspecified_arguments(absent)

    # Disabled API call, kept for reference:
    # response = requests.post(
    #     url="http://localhost:8000/predict",
    #     headers={"Content-Type": "application/json"},
    #     data=params,
    # )

    return "0"

In [8]:
# Build the model prompt from the candidate functions and the user's message.
# NOTE(review): build_raven_prompt comes from utils and is not shown here; presumably it
# renders each function's signature and docstring into the prompt — confirm.
prompt = build_raven_prompt(
    [fallback, predict_churn, get_age, get_sector_name, unknown_arguments],
    "I dont know the sector_name or the age",
)

In [9]:
# Send the prompt to the model and show its raw reply. In the recorded output below,
# the reply is a Python call expression naming one of the functions defined above.
result = query_raven(prompt)
print(result)

unknown_arguments(unknown_args=['sector_name', 'age'])


In [10]:
# Execute the model's reply as Python so that the function it selected is actually called.
# SECURITY NOTE(review): `result` is model-generated text and eval() runs whatever expression
# it contains. Consider parsing the call (e.g. with ast) and dispatching only to the known functions.
# Only ValueError (raised by get_age / get_sector_name) is handled here; any other
# exception, such as a SyntaxError from a malformed reply, propagates.
try:
    client = eval(result)
    print(client)
except ValueError as e:
    print(e)

I used default values for them
