In [1]:
from utils import query_raven, build_raven_prompt
import requests

In [2]:
def get_age(age: int = None) -> int:
    """
    Validates and returns the age value.

    Args:
        age (int, optional): The age of the client. Must be a non-negative integer. If not provided, defaults to None.

    Returns:
        int: The validated age. If no age is provided, None is returned.

    Raises:
        ValueError: If `age` is provided and is a negative integer.
    """
    if age is not None and age < 0:
        raise ValueError("age must be a non-negative integer.")

    return age

In [93]:
def get_sector_name(sector_name: str) -> str:
    """
    Gets the client's sector, and checks if it's a valid sector.

    Args:
        sector_name (str): The sector in which the client works. Must match exactly one of the following sectors:
                           ["financial", "industrial", "technology"].

    Returns:
        str: the sector name if valid.

    Raises:
        ValueError: If `sector_name` is not one of the specified values.
    """

    valid_sectors = ["financial", "industrial", "technology"]

    if sector_name not in valid_sectors:
        raise ValueError(f"sector_name must be one of the following: {valid_sectors}")

    return sector_name

In [4]:
def handle_unspecified_arguments(missing_args: list[str]) -> list[str]:
    """
    Handles the case where client details are missing.

    Args:
        missing_args (list[str]): The list of the missing client details.

    Returns:
        str: A message prompting the user to provide the missing arguments.

    Example:
        >>> handle_unspecified_arguments(['sector name'])
        ['sector name']
    """
    return f"Please provide the following missing arguments: {', '.join(missing_args)}"

In [5]:
def fallback(prompt: str) -> str:
    """
    Processes a general or fallback prompt and returns an appropriate response.

    This function is used to handle user inputs that do not correspond to specific function calls.
    It formats the prompt in a way that the LLM can understand and generates a response indicating
    that it is an assistant here to help with finding out whether a client will churn or not.

    Args:
        prompt (str): The user's input prompt which does not require a specific function call.

    Returns:
        str: The response generated by the LLM for the given prompt.

    Example:
        >>> fallback("Hi, how are you?")
        'I am an assistant here to help you determine if a client will churn. How can I assist you today?'

        >>> fallback("Can I ask a question?")
        'I am an assistant here to help you determine if a client will churn. How can I assist you today?'
    """
    no_function_calling_prompt = f"""
    <s> [INST] {prompt} [/INST]
    <s> [INST] I am called Raven. How can i assist you today?[/INST]
    """
    return query_raven(no_function_calling_prompt)

In [58]:
def unknown_arguments(unknown_arg: str) -> str:
    """
    Provides default values for unknown or unspecified arguments.

    This function returns a default value for specific arguments that are not known
    or provided by the user. If the argument is "age", it returns a default age of 25.
    If the argument is "sector_name", it returns a default sector name of "financial".

    Args:
        unknown_arg (str): The name of the unknown or unspecified argument. This should be
                           one of the following: "age" or "sector_name".

    Returns:
        str: The default value corresponding to the provided unknown argument.

    Raises:
        ValueError: If `unknown_arg` is not one of the recognized values ("age" or "sector_name").

    Example:
        >>> unknown_arguments("age")
        '25'

        >>> unknown_arguments("sector_name")
        'financial'
    """
    if unknown_arg == "age":
        return "25"
    if unknown_arg == "sector_name":
        return "financial"
    raise ValueError(f"Unknown argument: {unknown_arg}")

In [22]:
def predict_churn(age: int = None, sector_name: str = None) -> dict:
    """
    Calls an API to predict whether or not the client will churn.

    Args:
        age (int): The age of the client. Must be a non-negative integer.
        sector_name (str): The sector in which the client works.

    Returns:
        dict: The prediction from the API.
    """

    params = locals().copy()

    missing_args = [key for key, value in params.items() if value is None]
    if missing_args:
        return handle_unspecified_arguments(missing_args)

    # response = requests.post(
    #     url="http://localhost:8000/predict",
    #     headers={"Content-Type": "application/json"},
    #     data=params,
    # )
    return params

In [94]:
prompt = build_raven_prompt(
    [fallback, predict_churn, get_age, get_sector_name],
    "if a client's favorite sector is financial, will they churn?",
)

In [95]:
result = query_raven(prompt)
print(result)

predict_churn(sector_name=get_sector_name(sector_name='financial'))


In [10]:
try:
    client = eval(result)
    print(client)
except ValueError as e:
    print(e)

I used default values for them


In [106]:
class Chatbot:
    def __init__(self):
        self.history = []

    def add_to_history(self, user_input: str, bot_response: str):
        self.history.append(f"User: {user_input}")
        self.history.append(f"Bot: {bot_response}")

    def get_conversation_history(self) -> str:
        return "\n".join(self.history)

    def generate_prompt(self, user_input: str) -> str:
        history = self.get_conversation_history()
        return build_raven_prompt(
            [fallback, predict_churn, get_age, get_sector_name, unknown_arguments],
            f"{history}\nUser: {user_input}\nBot:",
        )

    def get_response(self, prompt: str) -> str:
        # Simulate sending prompt to LLM and getting a response
        # Replace this with your actual API call to the LLM
        func = query_raven(prompt)
        print(func)
        return eval(func)

    def handle_input(self, user_input: str) -> str:
        prompt = self.generate_prompt(user_input)
        # print(prompt)
        response = self.get_response(prompt)
        self.add_to_history(user_input, response)
        return response


# Initialize the chatbot
chatbot = Chatbot()

user_input = ""

while user_input != "q":
    user_input = input("input prompt: ")
    response = chatbot.handle_input(user_input)
    print(f"Bot: {response}")
# Example conversation
# user_inputs = [
#     "If a client his favorite sector is financial, will he churn?",
#     "I dont know",
# ]

# for input_text in user_inputs:
#     print(f"User: {input_text}")
#     response = chatbot.handle_input(input_text)
#     print(f"Bot: {response}")

fallback(prompt='hi')
Bot: Hello! I'm glad you're here. I'm Raven, and I'm here to help you with any questions or problems you might have. Is there anything specific you'd like to talk about or ask for help with?
predict_churn(age=unknown_arguments(unknown_arg='age'), sector_name=unknown_arguments(unknown_arg='sector_name'))
Bot: {'age': '25', 'sector_name': 'financial'}
predict_churn(age=unknown_arguments(unknown_arg='age'), sector_name=unknown_arguments(unknown_arg='sector_name'))
Bot: {'age': '25', 'sector_name': 'financial'}
