In [2]:
import os
import json
from groq import Groq
from dotenv import find_dotenv, load_dotenv

# Populate os.environ from the .env file located by find_dotenv(), so secrets
# such as the API key do not have to be hardcoded in the notebook.
load_dotenv(find_dotenv())

True

In [3]:
# Groq API client. The key is read from the GROQ_API_KEY environment variable;
# if the variable is not set this line raises KeyError.
client = Groq(api_key=os.environ["GROQ_API_KEY"])

## System prompt

In [4]:
# System message sent as the first entry of the conversation; it tells the
# model that its only job is to pick the tool matching the next user message.
SYSTEM_PROMPT = (
    "You are part of a system that teaches programming.\n"
    "Choose the function that fits best the next user message.\n"
)

## Functions

### Tópico

In [None]:
# Tool schema for explaining a programming topic.
# The description strings are what the model reads when choosing a tool.
EXPLAIN_FUNCTION_NAME = "explain_topic"
EXPLAIN_FUNCTION_DESC = (
    "Creates an explanation for a programming topic for the user.\n"
    "Takes into account previous explanations in order to write clearer explanations, "
    "if the topic has already been explained before."
)
EXPLAIN_OBS_DESC = (
    "Some observations about the previous messages "
    "that should be considered when writing the explanation."
)

# Single required string argument: free-form observations about the chat so far.
create_explanation_function = dict(
    type="function",
    function=dict(
        name=EXPLAIN_FUNCTION_NAME,
        description=EXPLAIN_FUNCTION_DESC,
        parameters=dict(
            type="object",
            properties=dict(
                observations=dict(type="string", description=EXPLAIN_OBS_DESC),
            ),
            required=["observations"],
        ),
    ),
)


# TODO: placeholder implementation
def explain_topic(*args, **kwargs):
    """Stub handler for the ``explain_topic`` tool.

    Positional arguments are accepted and ignored; the keyword arguments are
    returned unchanged as a dict.
    """
    received = kwargs
    return received

### Questão

In [None]:
# Tool schema for quizzing the user about the last explained topic.
# The description strings are what the model reads when choosing a tool.
QUESTION_FUNCTION_NAME = "create_topic_question"
QUESTION_FUNCTION_DESC = (
    "Call this tool to create a question for the user about the last topic.\n"
    "ALWAYS call this tool ONLY if the user said that they already understood "
    "the previous explanation."
)
QUESTION_OBS_DESC = (
    "Some observations about the previous messages "
    "that should be considered when writing the question."
)

# Single required string argument: free-form observations about the chat so far.
create_question_function = dict(
    type="function",
    function=dict(
        name=QUESTION_FUNCTION_NAME,
        description=QUESTION_FUNCTION_DESC,
        parameters=dict(
            type="object",
            properties=dict(
                observations=dict(type="string", description=QUESTION_OBS_DESC),
            ),
            required=["observations"],
        ),
    ),
)


# TODO
def create_topic_question(*args, **kwargs):
    """Stub handler for the ``create_topic_question`` tool.

    Positional arguments are accepted and ignored; the keyword arguments are
    returned unchanged as a dict.
    """
    received = kwargs
    return received

### Avalia questão

In [None]:
# Tool schema for grading the user's answer to the last question.
# The description strings are what the model reads when choosing a tool, so the
# trigger condition below must describe this tool only (the previous wording
# contained a leftover fragment of the question tool's condition).
EVALUATE_FUNCTION_NAME = "evaluate_user_answer"
EVALUATE_FUNCTION_DESC = """\
Call this tool to decide if the answer for the last question is correct.
ALWAYS call this tool ONLY if the user answered the last question.\
"""
EVALUATE_OBS_DESC = """\
Some observations about the previous messages that should be considered when evaluating the question.\
"""

# Single required string argument: free-form observations about the chat so far.
evaluate_question_function = {
    "type": "function",
    "function": {
        "name": EVALUATE_FUNCTION_NAME,
        "description": EVALUATE_FUNCTION_DESC,
        "parameters": {
            "type": "object",
            "properties": {
                "observations": {"type": "string", "description": EVALUATE_OBS_DESC}
            },
            "required": ["observations"],
        },
    },
}


# TODO
def evaluate_user_answer(*args, **kwargs):
    """Stub handler for the ``evaluate_user_answer`` tool.

    Positional arguments are accepted and ignored; the keyword arguments are
    returned unchanged as a dict.
    """
    received = kwargs
    return received

### Atualiza estado

In [8]:
# Tool schema for advancing the curriculum after a correct answer.
# The description strings are what the model reads when choosing a tool.
UPDATE_FUNCTION_NAME = "update_curriculum"
UPDATE_FUNCTION_DESC = (
    "Call this tool to update the curriculum.\n"
    "ALWAYS call this tool ONLY if the user answered the last question correctly."
)
UPDATE_OBS_DESC = (
    "Some observations about the previous messages "
    "that should be considered when updating the curriculum."
)

# Single required string argument: free-form observations about the chat so far.
update_curriculum_function = dict(
    type="function",
    function=dict(
        name=UPDATE_FUNCTION_NAME,
        description=UPDATE_FUNCTION_DESC,
        parameters=dict(
            type="object",
            required=["observations"],
            properties=dict(
                observations=dict(type="string", description=UPDATE_OBS_DESC),
            ),
        ),
    ),
)


# TODO
def update_curriculum(*args, **kwargs):
    """Stub handler for the ``update_curriculum`` tool.

    Positional arguments are accepted and ignored; the keyword arguments are
    returned unchanged as a dict.
    """
    received = kwargs
    return received

### Gerar resposta

In [9]:
# Fallback tool the model can select when none of the other tools applies.
# It takes no arguments, so the parameter schema is an object with no
# properties; "required" belongs inside that schema, next to "properties",
# exactly as in the other tool definitions.
do_nothing_function = {
    "type": "function",
    "function": {
        "name": "response_without_tools",
        "description": "If you feel like you don't need to call a function",
        "parameters": {
            "type": "object",
            "properties": {},
            "required": [],
        },
    },
}

## Roteador

In [None]:
def execute_router(
    client,
    tools,
    tool_names,
    weights,
    messages,
    model="llama3-groq-70b-8192-tool-use-preview",
):
    """Ask the model to choose a tool and return the highest-weighted valid call.

    Args:
        client: Object exposing ``chat.completions.create`` (e.g. a Groq client).
        tools: Tool schemas forwarded to the completion request.
        tool_names: Names of the tools that are accepted; calls to any other
            name are discarded.
        weights: Mapping from tool name to priority. The call with the highest
            weight is selected, so a fallback tool weighted ``-inf`` is only
            chosen when nothing else was requested. Names missing from the
            mapping count as ``0.0``.
        messages: Chat history forwarded to the completion request.
        model: Model identifier forwarded to the completion request.

    Returns:
        The selected tool-call object taken from the model response. On equal
        weights the call that appears first in the response wins.

    Raises:
        ValueError: If the response contains no call to an accepted tool.
    """
    chat_completion = client.chat.completions.create(
        messages=messages,
        model=model,
        tools=tools,
        tool_choice="required",
    )

    # Guard against a response whose ``tool_calls`` is None instead of a list.
    requested_calls = chat_completion.choices[0].message.tool_calls or []

    # TODO: validate the tool-call arguments
    candidates = [
        call for call in requested_calls if call.function.name in tool_names
    ]
    if not candidates:
        raise ValueError(
            "The model response contains no call to a known tool "
            f"(expected one of: {list(tool_names)})"
        )

    # ``max`` keeps the first of several equally weighted calls.
    return max(candidates, key=lambda call: weights.get(call.function.name, 0.0))

In [None]:
def execute_tool_call(tool_call):
    """Run the local handler that corresponds to a tool call.

    Args:
        tool_call: Object with ``function.name`` (the tool name) and
            ``function.arguments`` (JSON text holding the keyword arguments).

    Returns:
        The return value of the matching handler, or a fixed placeholder
        string for ``response_without_tools``.

    Raises:
        ValueError: If the tool name is not one of the known tools. Malformed
            JSON arguments raise ``json.JSONDecodeError``, which is a
            ``ValueError`` subclass.
    """
    name = tool_call.function.name
    raw_arguments = tool_call.function.arguments

    # A parameterless tool may arrive with empty/missing arguments or a JSON
    # ``null``; both mean "no keyword arguments" rather than a decode error.
    args = json.loads(raw_arguments) if raw_arguments else {}
    if args is None:
        args = {}

    if name == "explain_topic":
        tool_response = explain_topic(**args)
    elif name == "create_topic_question":
        tool_response = create_topic_question(**args)
    elif name == "evaluate_user_answer":
        tool_response = evaluate_user_answer(**args)
    elif name == "update_curriculum":
        tool_response = update_curriculum(**args)
    elif name == "response_without_tools":
        tool_response = "preguica de responderkkkj"
    else:
        raise ValueError(name)

    return tool_response

In [12]:
# Routing priority per tool name. The router looks every accepted tool up in
# this mapping, so each registered tool needs an entry keyed by the exact name
# declared in its schema (the explanation tool is named "explain_topic").
weights = {
    "response_without_tools": float("-inf"),
    EXPLAIN_FUNCTION_NAME: 1.0,
    QUESTION_FUNCTION_NAME: 1.0,
    EVALUATE_FUNCTION_NAME: 1.0,
    UPDATE_FUNCTION_NAME: 1.0,
}

# Tool schemas offered to the model, and the names the router will accept.
tools = [
    create_explanation_function,
    create_question_function,
    evaluate_question_function,
    update_curriculum_function,
    do_nothing_function,
]
tool_names = [tool["function"]["name"] for tool in tools]

# Every tool offered to the model must have a routing weight.
assert set(tool_names) <= set(weights), "tool without a routing weight"

# Sample conversation used to exercise the router.
messages = [
    {
        "role": "system",
        "content": SYSTEM_PROMPT,
    },
    {
        "role": "user",
        "content": "Explain if else to me",
    },
]

In [None]:
# Route the sample conversation. This sends a live request through `client`;
# as the last expression of the cell, the returned tool call is displayed.
execute_router(
    client=client,
    tools=tools,
    tool_names=tool_names,
    weights=weights,
    messages=messages,
)