In [10]:
from typing import NoReturn

import requests

from utils import query_raven, build_raven_prompt

In [11]:
def get_age(age: int = None) -> int:
    """
    Return the client's age after checking that it is not negative.

    Args:
        age (int, optional): The client's age, as a non-negative integer. Defaults to None.

    Returns:
        int: The age exactly as passed in, or None when no age was given.

    Raises:
        ValueError: If `age` is given and is less than zero.
    """
    if age is None:
        # Nothing to validate; an absent age is passed straight through.
        return None

    if age < 0:
        raise ValueError("age must be a non-negative integer.")

    return age

In [12]:
def get_sector_name(sector_name: str = None) -> str:
    """
    Return the client's sector name after confirming it is a supported sector.

    Args:
        sector_name (str): The sector the client works in. Must be exactly one of
                           ["financial", "industrial", "technology"].

    Returns:
        str: The sector name, unchanged, when it is one of the supported values.

    Raises:
        ValueError: If `sector_name` is missing or not one of the supported values.
    """
    allowed = ["financial", "industrial", "technology"]

    if sector_name in allowed:
        return sector_name

    raise ValueError(f"sector_name must be one of the following: {allowed}")

In [13]:
def handle_unspecified_arguments(missing_args: list[str]) -> NoReturn:
    """
    Signal that one or more required client details were not supplied.

    This function never returns normally: it always raises, so a caller can
    write ``return handle_unspecified_arguments(...)`` and the error simply
    propagates.

    Args:
        missing_args (list[str]): Names of the parameters that were left unset,
            e.g. ``["sector_name"]``. Names are reported exactly as given.

    Raises:
        ValueError: Always. ``args[0]`` is a dict of the form
            ``{"missing_params": [...]}`` listing the missing parameter names.

    Example:
        >>> handle_unspecified_arguments(['sector_name'])
        Traceback (most recent call last):
            ...
        ValueError: {'missing_params': ['sector_name']}
    """
    # Copy the names so that later mutation of the caller's list cannot
    # change the payload carried by the exception.
    raise ValueError({"missing_params": list(missing_args)})

In [14]:
def predict_churn(age: int = None, sector_name: str = None) -> dict:
    """
    Calls an API to predict whether or not the client will churn.

    Args:
        age (int): The age of the client. Must be a non-negative integer.
        sector_name (str): The sector in which the client works.

    Returns:
        dict: The prediction from the API.

    Raises:
        ValueError: If `age` or `sector_name` is not provided.
    """
    # Build the payload explicitly instead of snapshotting locals(): a
    # locals() snapshot silently picks up any local variable added above it.
    params = {"age": age, "sector_name": sector_name}

    missing_args = [name for name, value in params.items() if value is None]
    if missing_args:
        # handle_unspecified_arguments always raises
        # ValueError({"missing_params": [...]}), so this never returns a value.
        return handle_unspecified_arguments(missing_args)

    # TODO: replace the placeholder below with the real API call. Send the
    # payload with ``json=params`` (not ``data=params``): ``data=`` form-encodes
    # a dict even when a JSON Content-Type header is set, while ``json=``
    # serializes it and sets the header itself.
    #
    # response = requests.post(
    #     url="http://localhost:8000/predict",
    #     json=params,
    # )
    # response.raise_for_status()
    # return response.json()

    # NOTE(review): placeholder result -- a string, not the dict promised by
    # the signature and docstring.
    return "0"

In [15]:
# Build the prompt that asks Raven (via the helpers in `utils`) which of these
# tool functions to call for the question below.
# NOTE(review): build_raven_prompt presumably renders each function's signature
# and docstring into the prompt -- confirm in utils; if so, editing those
# docstrings/annotations changes what the model sees.
# The question says "finance" while get_sector_name only accepts "financial";
# the model's printed reply in the next cell maps it to 'financial'.
prompt = build_raven_prompt(
    [predict_churn, get_age, get_sector_name],
    "If a client, finance is their favorite sector, and they are 25 years old will they churn?",
)

In [16]:
# Query the model. As the printed output of this cell shows, the reply is a
# single Python call expression composing the tool functions (as a string).
result = query_raven(prompt)
print(result)

predict_churn(sector_name=get_sector_name(sector_name='financial'), age=get_age(age=25))


In [17]:
# `result` is text generated by the model, i.e. untrusted input.
# SECURITY(review): eval on model output is dangerous. Evaluating against an
# explicit allow-list of the tool functions, with builtins removed, limits
# what the expression can reach. It is NOT a sandbox: eval can still be
# escaped via attribute access. Parse with `ast` and dispatch the calls
# manually if the model output cannot be trusted.
try:
    client = eval(
        result,
        {"__builtins__": {}},
        {
            "predict_churn": predict_churn,
            "get_age": get_age,
            "get_sector_name": get_sector_name,
        },
    )
except ValueError as e:
    # Raised by get_age / get_sector_name on invalid input, or by
    # handle_unspecified_arguments when predict_churn is missing an argument.
    print(e)
else:
    print(client)

0
