In [1]:
from utils import query_raven, build_raven_prompt
import requests
from dotenv import load_dotenv
import os

# Load variables from a local .env file into the process environment;
# Chatbot._predict_churn later reads PREDICT_CHURN_URL via os.getenv.
_ = load_dotenv()

In [2]:
def get_age(age: int) -> int:
    """
    Validates and returns the age value.

    Args:
        age (int): The age of the client. Must be a non-negative integer. If not provided, defaults to None.

    Returns:
        int: The validated age. If no age is provided, None is returned.

    Raises:
        ValueError: If `age` is provided and is a negative integer.
    """
    if age is not None and age < 0:
        return None

    return age

In [3]:
def get_gender(gender: str) -> str:
    """
    Will return the gender of the client.

    Args:
        gender (str): the gender of the client. Can only be one of the following: 'Male', 'Female'. If not provided, defaults to None.

    Returns:
        str: The gender of the client.

    Raises:
        ValueError: If `gender` is not one of the valid options.
    """

    if gender is not None and gender not in ["Male", "Female"]:
        return None

    return gender

In [4]:
def get_most_frequent_sector_name(most_frequent_sector_name: str) -> str:
    """
    Will return the most frequent sector that the clients orders are in.

    Args:
        most_frequent_sector_name (str): The most frequent sector that the clients orders are in. Can only be one of the following:
        'Industries', 'Financials', 'Real Estate', 'Materials', 'Energy',
       'INVESTMENT', 'Consumer Discretionary', 'INDUSTRIAL',
       'Information Technology', 'Health Care', 'Consumer Staples',
       'REAL ESTATE', 'Telecommunication Services', 'Basic Materials',
       'Others', 'FOOD', 'Tourism', 'Telecommunications', 'SERVICES'.

    Returns:
        str: The sector name if valid. If no sector name is provided, None is returned.

    Raises:
        ValueError: If `most_frequent_sector_name` is not one of the specified values.
    """
    valid_sectors = [
        "Industries",
        "Financials",
        "Real Estate",
        "Materials",
        "Energy",
        "INVESTMENT",
        "Consumer Discretionary",
        "INDUSTRIAL",
        "Information Technology",
        "Health Care",
        "Consumer Staples",
        "REAL ESTATE",
        "Telecommunication Services",
        "Basic Materials",
        "Others",
        "FOOD",
        "Tourism",
        "Telecommunications",
        "SERVICES",
    ]
    if most_frequent_sector_name and most_frequent_sector_name not in valid_sectors:
        return None

    return most_frequent_sector_name

In [5]:
def get_risk_rate(risk_rate: str) -> str:
    """
    Will return the risk rate of the client

    Args:
        risk_rate (str): The risk rate of the client. Can only be one of the following: 'High', 'Low', 'Medium', 'Not Assigned'

    Returns:
        str: A valid category for the risk rate of the client

    Raises:
        ValueError: If `risk_rate` is not one of the valid options.

    """
    if risk_rate is not None and risk_rate not in [
        "High",
        "Low",
        "Medium",
        "Not Assigned",
    ]:
        return None

    return risk_rate

In [6]:
def get_completed_orders(completed_orders: str) -> str:
    """
    Will return the completed orders of the client

    Args:
        completed_orders (str): The completed orders of the client. can only be one of the following:
        'All', 'More Than Half', 'Less Than Half', 'None'

    Returns:
        str: A valid category that shows the trend of the clients completed orders

    Raises:
        ValueError: If `completed_orders` is not one of the valid options.
    """
    if completed_orders and completed_orders not in [
        "All",
        "More Than Half",
        "Less Than Half",
        "None",
    ]:
        return None

    return completed_orders

In [7]:
def get_canceled_orders(canceled_orders: str) -> str:
    """
    Will return the canceled orders of the client

    Args:
        canceled_orders (str): The canceled orders of the client. Can only be one of the following:
        'All', 'Most', 'Moderate', 'Little', 'None'

    Returns:
        str: A valid category that shows the trend of the clients canceled orders

    Raises:
        ValueError: If `canceled_orders` is not one of the valid options.
    """
    if canceled_orders and canceled_orders not in [
        "All",
        "Most",
        "Moderate",
        "Little",
        "None",
    ]:
        return None

    return canceled_orders

In [8]:
def get_average_price(average_price: float) -> float:
    """
    Will return the average price of the clients orders

    Args:
        average_price (int): The average price of the clients orders

    Returns:
        float: The validated average price of the clients orders

    Raises:
        ValueError: If `average_price` is a negative value
    """
    if average_price and average_price < 0:
        return None

    return average_price

In [9]:
def get_most_frequent_order_type(most_frequent_order_type: str) -> str:
    """
    Will return the most frequent order type of the client

    Args:
        most_frequent_order_type (str): The most frequent order type of the client. Can only be one of the following:
        'Buy', 'Sell'

    Returns:
        str: A valid category for the most frequent order type of the client

    Raises:
        ValueError: If `most_frequent_order_type` is not one of the valid options.
    """
    if most_frequent_order_type and most_frequent_order_type not in ["Buy", "Sell"]:
        return None

    return most_frequent_order_type

In [10]:
def get_most_frequent_execution_status(
    most_frequent_execution_status: str,
) -> str:
    """
    Will return the most frequent execution status of the clients orders

    Args:
        most_frequent_execution_status (str): The most frequent execution status of the clients orders.
        Can only be one of the following: 'Executed', 'Not Executed', 'Partially Executed'

    Returns:
        str: A valid category for the most frequent execution status of the clients orders

    Raises:
        ValueError: If `most_frequent_execution_status` is not one of the valid options.
    """
    if most_frequent_execution_status and most_frequent_execution_status not in [
        "Executed",
        "Not Executed",
        "Partially Executed",
    ]:
        return None

    return most_frequent_execution_status

In [11]:
def get_avg_order_rate_difference(avg_order_rate_difference: str) -> str:
    """
    Will return the change in the client's order activity

    Args:
        avg_order_rate_difference (str): The change in the client's order activity.
        Can only be one of the following: 'Increased', 'Decreased', 'Constant'

    Returns:
        str: A valid category for the change in the client's order activity

    Raises:
        ValueError: If `avg_order_rate__difference` is not one of the valid options.
    """
    if avg_order_rate_difference and avg_order_rate_difference not in [
        "Increased",
        "Decreased",
        "Constant",
    ]:
        return None

    return avg_order_rate_difference

In [12]:
def get_avg_order_quantity_rate_difference(
    avg_order_quantity_rate_difference: str,
) -> str:
    """
    Will return the change in the client's order quantity

    Args:
        avg_order_quantity_rate_difference (str): The change in the client's order quantity.
        Can only be one of the following: 'Increased', 'Decreased', 'Constant'

    Returns:
        str: A valid category for the change in the client's order quantity

    Raises:
        ValueError: If `avg_order_quantity_rate__difference` is not one of the valid options.
    """
    if (
        avg_order_quantity_rate_difference
        and avg_order_quantity_rate_difference
        not in [
            "Increased",
            "Decreased",
            "Constant",
        ]
    ):
        return None
    return avg_order_quantity_rate_difference

In [13]:
def handle_unspecified_client_data(missing_client_data: list[str]) -> str:
    """
    Builds the reply used when the user has not supplied every client detail.

    Args:
        missing_client_data (list[str]): Names of the client details that are still missing.

    Returns:
        str: A message asking the user to supply the listed details.
    """
    missing_names = ", ".join(missing_client_data)
    return "Please provide the following missing values: " + missing_names

In [14]:
def no_relevant_function(prompt: str) -> dict:
    """
    Call this when no other provided function can be called to answer the user query.

    Args:
       prompt: The prompt that cannot be answered by any other function calls.

    Returns:
        dict: A dict whose "message" is Raven's free-text reply to the prompt.
    """
    # Same text as before, written as adjacent literals: a leading newline, two
    # 4-space-indented [INST] lines, and a trailing 4-space indent.
    raven_prompt = (
        "\n"
        f"    <s> [INST] {prompt} [/INST]\n"
        "    <s> [INST] I am called Raven. How can i assist you today?[/INST]\n"
        "    "
    )
    reply = query_raven(raven_prompt)
    return {"message": reply}

In [15]:
def unknown_arguments(unknown_arg: str) -> str | int:
    """
    Provides default values for unknown or unspecified client details in a user query.

    This function returns a default value for specific client details that are not known
    or provided by the user.

    Args:
        unknown_arg (str): The name of the unknown or unspecified client detail. Could be one of the following:
        [
            age,
            risk_rate,
            gender,
            completed_orders,
            canceled_orders,
            avg_price,
            most_frequent_order_type,
            most_frequent_execution_status,
            most_frequent_sector_name,
            avg_order_rate__difference,
            avg_order_quantity_rate__difference
        ]

    Returns:
        str: The default value corresponding to the provided unknown client detail.
    """
    if unknown_arg == "age":
        return 40
    if unknown_arg == "gender":
        return "Male"
    if unknown_arg == "most_frequent_sector_name":
        return "Financials"
    if unknown_arg == "risk_rate":
        return "Not Assigned"
    if unknown_arg == "completed_orders":
        return "Most"
    if unknown_arg == "canceled_orders":
        return "Moderate"
    if unknown_arg == "avg_price":
        return 9.56
    if unknown_arg == "most_frequent_order_type":
        return "Sell"
    if unknown_arg == "most_frequent_execution_status":
        return "Executed"
    if unknown_arg == "avg_order_rate_difference":
        return "constant"
    if unknown_arg == "avg_order_quantity_rate_difference":
        return "constant"

In [16]:
def construct_client_dict(
    age: int = None,
    risk_rate: str = None,
    gender: str = None,
    completed_orders: str = None,
    canceled_orders: str = None,
    average_price: float = None,
    most_frequent_order_type: str = None,
    most_frequent_execution_status: str = None,
    most_frequent_sector_name: str = None,
    avg_order_rate_difference: str = None,
    avg_order_quantity_rate_difference: str = None,
) -> dict:
    """
    Constructs a dict with all the client details specified in the user query.

    Args:
        age (int, optional): The age of the client. If not provided, defaults to None.
        risk_rate (str, optional): The risk rate of the client. If not provided, defaults to None.
        gender (str, optional): The gender of the client. If not provided, defaults to None.
        completed_orders (str, optional): The completed orders of the client. If not provided, defaults to None.
        canceled_orders_ratio (str, optional): The canceled orders ratio of the client. If not provided, defaults to None.
        average_price (float, optional): The average price of the clients orders. If not provided, defaults to None.
        most_frequent_order_type (str, optional): The most frequent order type of the client. If not provided, defaults to None.
        most_frequent_execution_status (str, optional): The most frequent execution status of the clients orders. If not provided, defaults to None.
        most_frequent_sector_name (str, optional): The most frequent sector that the clients orders are in. If not provided, defaults to None.
        avg_order_rate_difference (str, optional): The change in the client's order activity. If not provided, defaults to None.
        avg_order_quantity_rate_difference (str, optional): The change in the client's order quantity. If not provided, defaults to None.

    Returns:
        dict: The client data.
    """

    provided_params = locals().copy()

    message = ""
    missing_args = [key for key, value in provided_params.items() if value is None]
    if missing_args:
        message = handle_unspecified_client_data(missing_args)

    return {"client_data": provided_params, "message": message}

In [17]:
class Chatbot:
    """
    Turns free-text descriptions of a client into churn predictions.

    Each user message is wrapped in a Raven function-calling prompt together with
    the previous client-data turns. The function call Raven answers with is
    evaluated against this notebook's tool functions. Once construct_client_dict
    reports no missing fields (empty message), the client data is posted to the
    churn service whose URL is read from the PREDICT_CHURN_URL environment variable.
    """

    # Seconds to wait for the churn service; without a timeout a hung service
    # would block the chat loop indefinitely.
    PREDICT_TIMEOUT_SECONDS = 30

    # construct_client_dict field names -> feature names sent to the churn service.
    _KEY_MAPPING = {
        "age": "Age",
        "gender": "IsMale",
        "average_price": "AvgPrice",
        "risk_rate": "RiskRate",
        "avg_order_rate_difference": "AvgOrderRate_Difference",
        "avg_order_quantity_rate_difference": "AvgQuantityOrderedRate_Difference",
        "completed_orders": "CompletedOrdersRatio",
        "canceled_orders": "CanceledOrdersRatio",
        "most_frequent_order_type": "Most_Frequent_OrderType",
        "most_frequent_execution_status": "Most_Frequent_ExecutionStatus",
        "most_frequent_sector_name": "Most_Frequent_SectorName",
    }

    def __init__(self):
        # Turns that produced client data; replayed to Raven as conversation context.
        self._history = []

    @staticmethod
    def _tools() -> list:
        """Return the functions Raven may call; also the only names visible to eval."""
        return [
            no_relevant_function,
            construct_client_dict,
            unknown_arguments,
            get_age,
            get_risk_rate,
            get_gender,
            get_completed_orders,
            get_canceled_orders,
            get_average_price,
            get_most_frequent_order_type,
            get_most_frequent_execution_status,
            get_most_frequent_sector_name,
            get_avg_order_rate_difference,
            get_avg_order_quantity_rate_difference,
        ]

    @staticmethod
    def _to_response(result) -> dict:
        """
        Normalize the value of an evaluated call into a dict with a "message" key.

        Raven may answer with a bare validator call (e.g. get_age(age=16)), which
        returns a str/int/None rather than a dict; without this normalization
        _add_to_history and handle_input would fail on such results.
        """
        if isinstance(result, dict) and "message" in result:
            return result
        if result is None:
            return {"message": "Sorry, I couldn't understand your request."}
        return {"message": str(result)}

    def _add_to_history(self, bot_response: dict, user_input: str):
        """Record a turn, but only when it carried client data."""
        if "client_data" in bot_response:
            self._history.append({"response": bot_response, "input": user_input})

    def _get_conversation(self) -> str:
        """Render the recorded turns as plain text for the Raven prompt."""
        return "".join(
            f"""
Available client data: {hist["response"]["client_data"]}
User input: {hist["input"]}
Bot: {hist["response"]["message"]}
"""
            for hist in self._history
        )

    def _generate_prompt(self, user_input: str) -> str:
        """Build the Raven prompt from the tool functions, past turns and the new input."""
        return build_raven_prompt(self._tools(), self._get_conversation(), user_input)

    def _get_response(self, prompt: str) -> dict:
        """Ask Raven for a function call, evaluate it, and return a response dict."""
        func = query_raven(prompt)
        if func is None:
            return {"message": "Sorry, I couldn't understand your request."}

        # SECURITY: `func` is model-generated code. Evaluating it with no builtins
        # and only the tool functions in scope limits what it can reach, but
        # eval() is NOT a sandbox. Parsing the call with `ast` and dispatching to
        # the whitelisted tools would be safer.
        namespace = {tool.__name__: tool for tool in self._tools()}
        try:
            result = eval(func, {"__builtins__": {}}, namespace)
        except Exception as e:  # malformed model output or bad arguments
            # Store text, not the exception object, so messages are always strings.
            return {"message": str(e) or repr(e)}

        return self._to_response(result)

    def _prepare_client_data(self, client_data: dict) -> dict:
        """
        Return a copy of client_data encoded and renamed for the churn service.

        gender becomes 1 for "Male" and 0 otherwise; keys listed in _KEY_MAPPING
        are renamed, any other keys are passed through unchanged. The input dict
        is not modified.
        """
        temp_client_data = client_data.copy()
        if "gender" in temp_client_data:
            temp_client_data["gender"] = 1 if temp_client_data["gender"] == "Male" else 0

        for old_key, new_key in self._KEY_MAPPING.items():
            if old_key in temp_client_data:
                temp_client_data[new_key] = temp_client_data.pop(old_key)

        return temp_client_data

    def _predict_churn(self, client_data: dict) -> str:
        """
        Post the client data to the churn service and describe its prediction.

        Always returns a human-readable string (never None), including when the
        service is unreachable, answers with a non-200 status, or returns an
        unexpected payload.
        """
        url = os.getenv("PREDICT_CHURN_URL")
        if not url:
            return "Request failed: PREDICT_CHURN_URL is not set"

        cleaned_client_data = self._prepare_client_data(client_data=client_data)

        try:
            response = requests.post(
                url=url, json=cleaned_client_data, timeout=self.PREDICT_TIMEOUT_SECONDS
            )
            if response.status_code != 200:
                return f"Request failed: HTTP {response.status_code}"
            data = response.json()
        # JSON decoding errors subclass ValueError.
        except (requests.RequestException, ValueError) as e:
            return f"Request failed: {e}"

        prediction = data.get("prediction") if isinstance(data, dict) else None
        if prediction == 1:
            return "The client will most likely churn"
        if prediction == 0:
            return "The client will not churn"
        return f"Request failed: unexpected response from churn service: {data}"

    def handle_input(self, user_input: str) -> dict:
        """Process one user message and return the bot's response dict."""
        prompt = self._generate_prompt(user_input)
        response = self._get_response(prompt)
        # The same dict is stored, so the churn message set below also lands in history.
        self._add_to_history(response, user_input)
        # An empty message from construct_client_dict means every field was supplied.
        if response.get("message") == "" and "client_data" in response:
            response["message"] = self._predict_churn(response["client_data"])
        return response

In [18]:
# Sample query describing a single male client. The pronouns are kept
# consistent ("he"/"his") so the model gets no conflicting gender cues.
USER_QUERY = (
    "The client is 21 years old with a low risk rate. "
    "He mainly buys stocks in the real estate sector with an average price of 7.65 "
    "and the orders are usually fully executed. "
    "He has increased order activity and increased quantity ordered. "
    "Most of his orders are completed and little of his orders are canceled."
)

In [19]:
# Interactive chat loop: type "exit" to stop. The sentinel is checked before the
# input is sent, so "exit" itself is never forwarded to the model.
chatbot = Chatbot()

while True:
    user_input = input("User: ")
    if user_input.strip().lower() == "exit":
        break
    print(f"user input: {user_input}")
    response = chatbot.handle_input(user_input)
    print(f"Bot: {response['message']}")

user input: The client is 21 years old with a low risk rate. He mainly buys stocks in the real estate sector with an average price of 7.65 and the orders are usually fully executed. He has increased order activity and increased quantity ordered. Most of his orders are completed and little of her orders are canceled.

Function:
def no_relevant_function(prompt: str) -> dict
    """
    Call this when no other provided function can be called to answer the user query.

    Args:
       prompt: The prompt that cannot be answered by any other function calls.
    """
    

Function:
def construct_client_dict(age: int = None, risk_rate: str = None, gender: str = None, completed_orders: str = None, canceled_orders: str = None, average_price: float = None, most_frequent_order_type: str = None, most_frequent_execution_status: str = None, most_frequent_sector_name: str = None, avg_order_rate_difference: str = None, avg_order_quantity_rate_difference: str = None) -> dict
    """
    Constructs a di

In [71]:
# Interactive chat loop: type "exit" to stop. The sentinel is checked before the
# input is sent, so "exit" itself is never forwarded to the model.
chatbot = Chatbot()

while True:
    user_input = input("User: ")
    if user_input.strip().lower() == "exit":
        break
    print(f"user input: {user_input}")
    response = chatbot.handle_input(user_input)
    print(f"Bot: {response['message']}")

user input: the client is a woman
construct_client_dict(gender=get_gender())
Bot: Please provide the following missing values: age, risk_rate, gender, completed_orders, canceled_orders, average_price, most_frequent_order_type, most_frequent_execution_status, most_frequent_sector_name, avg_order_rate_difference, avg_order_quantity_rate_difference
user input: shes 16
construct_client_dict(age=get_age(age=16), risk_rate=get_risk_rate(risk_rate='High'), gender=get_gender(gender='Female'), completed_orders=get_completed_orders(completed_orders='All'), canceled_orders=get_canceled_orders(canceled_orders='All'), average_price=get_average_price(average_price=100), most_frequent_order_type=get_most_frequent_order_type(most_frequent_order_type='Buy'), most_frequent_execution_status=get_most_frequent_execution_status(most_frequent_execution_status='Executed'), most_frequent_sector_name=get_most_frequent_sector_name(most_frequent_sector_name='Financial'), avg_order_rate_difference=get_avg_order_ra