This code is adapted from the [official VAPO tool calling notebook](https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/prompts/prompt_optimizer/vertex_ai_prompt_optimizer_sdk_tool_calling.ipynb) (Author: Ivan Nardini)

It is intended as a reference only and has not been tested across multiple environments.

### Authentication

In [None]:
# Run `!gcloud auth application-default login` once to add Vertex AI credentials to your environment.


### Initialize language model

In [1]:
# Google Cloud settings read by aiplatform.init() in the next cell.
# Replace the project values with your own before running.
PROJECT_ID = "vijay-sandbox-335018"
# Not referenced by any other code visible in this notebook.
PROJECT_NUMBER = 607718892999
REGION = "us-central1" 

In [None]:
from google.cloud import aiplatform
from vertexai.preview.generative_models import GenerativeModel
from vertexai.generative_models import FunctionDeclaration, Tool, ToolConfig

# Point the Vertex AI SDK at the project and region configured above.
aiplatform.init(project=PROJECT_ID, location=REGION)

# Create the Gemini model handle shared by every inference stage below.
lm = GenerativeModel(model_name="gemini-2.0-flash-001")

# Smoke test: one request to confirm that credentials and the model name work.
lm.generate_content("This is a test!")



### Define the tools

In [3]:
from typing import Any, Union

def get_company_information_api(content: dict[str, Any]) -> str:
    """Simulate an API call that returns a financial overview for a company.

    Args:
        content: Tool-call arguments; the "ticker" entry selects the company.

    Returns:
        The canned overview for the ticker, or a fallback message when the
        ticker is not in the table.

    Raises:
        KeyError: If ``content`` has no "ticker" entry.
    """
    ticker = content["ticker"]
    overview_by_ticker = {
        "AAPL": "Apple maintains a robust financial position with substantial cash reserves and consistent profitability, fueled by its strong brand and loyal customer base. However, growth is slowing and the company faces competition.",
        "ADBE": "Adobe financials are robust, driven by its successful transition to a subscription-based model for its creative and document cloud software.  Profitability and revenue growth are strong.",
        "AMD": "AMD exhibits strong financial performance, gaining market share in the CPU and GPU markets.  Revenue growth and profitability are healthy, driven by strong product offerings.",
        "AMZN": "Amazon financials are mixed, with its e-commerce business facing margin pressure while its cloud computing division (AWS) delivers strong profitability and growth. Its overall revenue remains high but profitability is a concern.",
        "ASML": "ASML boasts a strong financial position due to its monopoly in the extreme ultraviolet lithography market, essential for advanced semiconductor manufacturing.  High profitability and growth are key strengths.",
        "AVGO": "Broadcom maintains healthy financials, driven by its semiconductor and infrastructure software solutions. Acquisitions have played a role in its growth strategy, with consistent profitability and cash flow.",
        "BABA": "Alibaba financials are substantial but facing challenges from regulatory scrutiny in China and increased competition.  E-commerce revenue remains strong but growth is slowing.",
        "BKNG": "Booking Holdings financials are closely tied to the travel industry.  Revenue growth is recovering post-pandemic but profitability can fluctuate based on global travel trends.",
        "CRM": "Salesforce shows robust revenue growth from its cloud-based CRM solutions.  Profitability is improving but competition remains strong.",
        "CSCO": "Cisco financials show moderate growth, transitioning from hardware to software and services.  Profitability is stable but the company faces competition in the networking market.",
        "GOOGL": "Alphabet exhibits strong financials driven by advertising revenue, though facing regulatory scrutiny.  Diversification into other ventures provides growth opportunities but profitability varies.",
        "IBM": "IBM financials are in a state of transformation, shifting focus to hybrid cloud and AI.  Revenue growth is modest, with profitability impacted by legacy businesses.",
        "INTU": "Intuit showcases healthy financials, benefiting from its strong position in tax and financial management software.  Revenue growth and profitability are consistent, fueled by recurring subscription revenue.",
        "META": "Meta Platforms financial performance is tied closely to advertising revenue, facing headwinds from competition and changing privacy regulations.  Investments in the metaverse represent a long-term, high-risk bet.",
        "MSFT": "Microsoft demonstrates healthy financials, benefiting from diversified revenue streams including cloud computing (Azure), software, and hardware.  The company exhibits consistent growth and profitability.",
        "NFLX": "Netflix exhibits strong revenue but faces challenges in maintaining subscriber growth and managing content costs. Profitability varies, and competition in the streaming market is intense.",
        "NOW": "ServiceNow demonstrates strong financials, fueled by its cloud-based workflow automation platform.  Revenue growth and profitability are high, reflecting increased enterprise adoption.",
        "NVDA": "NVIDIA boasts strong financials, driven by its dominance in the GPU market for gaming, AI, and data centers.  High revenue growth and profitability are key strengths.",
        "ORCL": "Oracle financials are in transition, shifting towards cloud-based services. Revenue growth is moderate, and profitability remains stable.  Legacy businesses still contribute significantly.",
        "QCOM": "QUALCOMM financials show strong performance driven by its leadership in mobile chipsets and licensing.  Profitability is high, and growth is tied to the mobile market and 5G adoption.",
        "SAP": "SAP demonstrates steady financials with its enterprise software solutions.  Transition to the cloud is ongoing and impacting revenue growth and profitability.",
        "SMSN": "Samsung financials are diverse, reflecting its presence in various sectors including mobile phones, consumer electronics, and semiconductors. Profitability varies across divisions but the company holds significant cash reserves.",
        "TCEHY": "Tencent financials are driven by its dominant position in the Chinese gaming and social media market. Revenue growth is strong but regulatory risks in China impact its performance.",
        "TSLA": "Tesla financials show strong revenue growth driven by electric vehicle demand, but profitability remains volatile due to production and investment costs. The company high valuation reflects market optimism for future growth.",
        "TSM": "TSMC, a dominant player in semiconductor manufacturing, showcases robust financials fueled by high demand for its advanced chips. Profitability is strong and the company enjoys a technologically advanced position.",
    }
    if ticker in overview_by_ticker:
        return overview_by_ticker[ticker]
    # NOTE(review): the spelling "overwiew" is kept as-is because tool outputs
    # are compared by exact equality during evaluation; confirm the dataset's
    # expected text before correcting it.
    return "No company overwiew found"


def get_stock_price_api(content: dict[str, Any]) -> Union[int, str]:
    """Simulate an API call that returns the most recent stock price.

    Args:
        content: Tool-call arguments; the "ticker" entry selects the company.

    Returns:
        The canned price as an ``int`` for a known ticker, or the string
        "No stock price found" for an unknown one. The annotation is a union
        because the fallback is not an integer.

    Raises:
        KeyError: If ``content`` has no "ticker" entry.
    """
    stock_prices = {
        "AAPL": 225,
        "ADBE": 503,
        "AMD": 134,
        "AMZN": 202,
        "ASML": 658,
        "AVGO": 164,
        "BABA": 88,
        "BKNG": 4000,
        "CRM": 325,
        "CSCO": 57,
        "GOOGL": 173,
        "IBM": 201,
        "INTU": 607,
        "META": 553,
        "MSFT": 415,
        "NFLX": 823,
        "NOW": 1000,
        "NVDA": 141,
        "ORCL": 183,
        "QCOM": 160,
        "SAP": 228,
        "SMSN": 38,
        "TCEHY": 51,
        "TSLA": 302,
        "TSM": 186,
    }
    return stock_prices.get(content["ticker"], "No stock price found")


def get_company_news_api(content: dict[str, Any]) -> str:
    """Simulate an API call that returns a recent news headline for a company.

    Args:
        content: Tool-call arguments; the "ticker" entry selects the company.

    Returns:
        The canned headline for the ticker, or "No news available" when the
        ticker is not in the table.

    Raises:
        KeyError: If ``content`` has no "ticker" entry.
    """
    headline_by_ticker = {
        "AAPL": "Apple unveils new iPhone, market reaction muted amid concerns about slowing growth.",
        "ADBE": "Adobe integrates AI features into Creative Suite, attracting creative professionals.",
        "AMD": "AMD gains market share in server CPUs, competing with Intel.",
        "AMZN": "Amazon stock dips after reporting lower-than-expected Q3 profits due to increased shipping costs.",
        "ASML": "ASML benefits from high demand for advanced chip manufacturing equipment.",
        "AVGO": "Broadcom announces new acquisition in the semiconductor space.",
        "BABA": "Alibaba stock faces uncertainty amid ongoing regulatory scrutiny in China.",
        "BKNG": "Booking Holdings stock recovers as travel demand rebounds post-pandemic.",
        "CRM": "Salesforce launches new AI-powered CRM tools for enterprise customers.",
        "CSCO": "Cisco stock rises after positive earnings report, focus on networking solutions.",
        "GOOGL": "Alphabet announces new AI-powered search features, aiming to compete with Microsoft.",
        "IBM": "IBM focuses on hybrid cloud solutions, showing steady growth in enterprise segment.",
        "INTU": "Intuit stock dips after announcing price increases for its tax software.",
        "META": "Meta shares rise after positive user growth figures in emerging markets.",
        "MSFT": "Microsoft expands AI integration across its product suite, boosting investor confidence.",
        "NFLX": "Netflix subscriber growth slows, competition heats up in streaming landscape.",
        "NOW": "ServiceNow sees strong growth in its cloud-based workflow automation platform.",
        "NVDA": "Nvidia stock jumps on strong earnings forecast, driven by AI demand.",
        "ORCL": "Oracle cloud revenue continues strong growth, exceeding market expectations.",
        "QCOM": "Qualcomm expands its 5G modem business, partnering with major smartphone manufacturers.",
        "SAP": "SAP cloud transition continues, but faces challenges in attracting new clients.",
        "SMSN": "Samsung unveils new foldable phones, looking to gain market share.",
        "TCEHY": "Tencent faces regulatory pressure in China, impacting investor sentiment.",
        "TSLA": "Tesla stock volatile after price cuts and production increases announced.",
        "TSM": "TSMC reports record chip demand but warns of potential supply chain disruptions.",
    }
    # Read the ticker outside the try block so a missing "ticker" key still
    # propagates as KeyError instead of being mistaken for an unknown company.
    ticker = content["ticker"]
    try:
        return headline_by_ticker[ticker]
    except KeyError:
        return "No news available"


def get_company_sentiment_api(content: dict[str, Any]) -> str:
    """Simulate an API call that returns the market sentiment for a company.

    Args:
        content: Tool-call arguments; the "ticker" entry selects the company.

    Returns:
        The canned sentiment label for the ticker, or "No sentiment available"
        when the ticker is not in the table.

    Raises:
        KeyError: If ``content`` has no "ticker" entry.
    """
    sentiment_by_ticker = {
        "AAPL": "Neutral",
        "ADBE": "Neutral",
        "AMD": "Neutral",
        "AMZN": "Neutral",
        "ASML": "Bearish/Undervalued",
        "AVGO": "Neutral",
        "BABA": "Neutral",
        "BKNG": "Neutral",
        "CRM": "Neutral",
        "CSCO": "Neutral",
        "GOOGL": "Neutral",
        "IBM": "Neutral",
        "INTU": "Mixed/Bullish",
        "META": "Neutral",
        "MSFT": "Neutral",
        "NFLX": "Neutral",
        "NOW": "Bullish/Overvalued",
        "NVDA": "Neutral",
        "ORCL": "Neutral",
        "QCOM": "Neutral",
        "SAP": "Neutral",
        "SMSN": "Neutral",
        "TCEHY": "Neutral",
        "TSLA": "Slightly Overvalued",
        "TSM": "Neutral",
    }
    requested_ticker = content["ticker"]
    fallback = "No sentiment available"
    return sentiment_by_ticker.get(requested_ticker, fallback)

In [4]:
# Tool schema shown to the model for the company-overview lookup.
# It takes a single required string argument, "ticker".
get_company_information = FunctionDeclaration(
    name="get_company_information",
    description="Retrieves financial performance to provide an overview for a company.",
    parameters={
        "type": "object",
        "required": ["ticker"],
        "properties": {
            "ticker": {"type": "string", "description": "Stock ticker for a given company"},
        },
    },
)

# Tool schema shown to the model for the stock-price lookup.
# It takes a single required argument, "ticker".
get_stock_price = FunctionDeclaration(
    name="get_stock_price",
    description="Only returns the current stock price (in dollars) for a company.",
    parameters={
        "type": "object",
        "properties": {
            "ticker": {
                # Tickers are symbols such as "AAPL", so the argument must be
                # a string. It was previously declared as "integer", unlike
                # every other tool declaration in this notebook.
                "type": "string",
                "description": "Stock ticker for a company",
            }
        },
        "required": ["ticker"],
    },
)

# Tool schema shown to the model for the news-headline lookup.
# It takes a single required string argument, "ticker".
get_company_news = FunctionDeclaration(
    name="get_company_news",
    description="Get the latest news headlines for a given company.",
    parameters={
        "type": "object",
        "required": ["ticker"],
        "properties": {
            "ticker": {"type": "string", "description": "Stock ticker for a company."},
        },
    },
)

# Tool schema shown to the model for the market-sentiment lookup.
# It takes a single required string argument, "ticker".
get_company_sentiment = FunctionDeclaration(
    name="get_company_sentiment",
    description="Returns the overall market sentiment for a company.",
    parameters={
        "type": "object",
        "required": ["ticker"],
        "properties": {
            "ticker": {"type": "string", "description": "Stock ticker for a company"},
        },
    },
)

In [5]:
# Dispatch table: the tool name chosen by the model -> the local Python
# stand-in that simulates the corresponding API.
function_map = {
    "get_company_information": get_company_information_api,
    "get_stock_price": get_stock_price_api,
    "get_company_news": get_company_news_api,
    "get_company_sentiment": get_company_sentiment_api,
}

# Bundle the four declarations into the single Tool passed to the model.
tools = Tool(
    function_declarations=[
        get_company_information,
        get_stock_price,
        get_company_news,
        get_company_sentiment,
    ]
)

# Function-calling mode ANY, restricted to the tools this notebook can
# execute. The allow-list is taken from the dispatch table's keys, which are
# the same four names in the same order.
_calling_config = ToolConfig.FunctionCallingConfig
tool_config = ToolConfig(
    function_calling_config=_calling_config(
        mode=_calling_config.Mode.ANY,
        allowed_function_names=list(function_map),
    )
)

### Seed Prompts

In [21]:
def stage1_prompt(question):
    """Build the seed stage-1 prompt that asks the model to choose a tool.

    Args:
        question: The user question interpolated into the prompt.

    Returns:
        The prompt text. The template's leading newline and four-space
        indentation are part of the returned string.
    """
    prompt = f"""
    Determine the best tool to use based on the question.
    Question: {question}

    """
    return prompt

def stage2_prompt(question, tool_call_response):
    """Build the stage-2 prompt that asks the model to answer from tool output.

    Args:
        question: The user question interpolated into the prompt.
        tool_call_response: The tool result interpolated into the prompt.

    Returns:
        The prompt text. The template's leading newline and four-space
        indentation are part of the returned string.
    """
    prompt = f"""
    Answer the question using the tool call response.
    Question: {question}
    Tool call response: {tool_call_response}
    """
    return prompt

### Stage 1 Inference

In [None]:
def stage1_inference(prompt):
    """Run stage 1: let the model pick a tool, then execute that tool locally.

    Args:
        prompt: Stage-1 prompt text sent to the model.

    Returns:
        Whatever the selected ``function_map`` entry returns for the
        arguments proposed by the model.

    Raises:
        ValueError: If the model names a function that is not in
            ``function_map``.
    """
    response = lm.generate_content(
        prompt,
        tools=[tools],
        tool_config=tool_config,
    )
    # Only the first part of the first candidate is inspected.
    function_call = response.candidates[0].content.parts[0].function_call
    function_name = function_call.name
    if function_name not in function_map:
        # Fail with a clear message. Previously this path reached the return
        # statement with the result variable unbound (UnboundLocalError).
        raise ValueError(
            f"Model requested unknown function {function_name!r}; "
            f"expected one of {sorted(function_map)}"
        )
    return function_map[function_name](function_call.args)

# Demo run of stage 1. `question` and `tool_call_response` are module-level
# names that the later demo cells read as well.
question = "What is the current sentiment of Apple?"
tool_call_response = stage1_inference(stage1_prompt(question))
print(tool_call_response)

### Stage 2 Inference

In [None]:
def stage2_inference(question, tool_call_response):
    """Run stage 2: ask the model to answer the question from the tool output.

    Args:
        question: The user question.
        tool_call_response: The value returned by the stage-1 tool call.

    Returns:
        The text of the first part of the first response candidate.
    """
    prompt = stage2_prompt(question, tool_call_response)
    response = lm.generate_content(prompt)
    first_part = response.candidates[0].content.parts[0]
    return first_part.text

# Demo: answer the sample question from the tool output produced by the
# stage-1 demo. As the cell's last expression, the notebook displays its value.
stage2_inference(question, tool_call_response)

### Multi stage inference

In [None]:
def multi_stage_inference(stage_1_prompt, question):
    """Chain both stages: tool selection and execution, then answer generation.

    Args:
        stage_1_prompt: The already-rendered stage-1 prompt text (not the
            prompt-building function).
        question: The user question, passed on to stage 2.

    Returns:
        The stage-2 answer text.
    """
    return stage2_inference(question, stage1_inference(stage_1_prompt))

# Demo: run both stages end to end for the sample question, using the seed prompt.
multi_stage_inference(stage1_prompt(question), question)

### Setup the train and test datasets







In [10]:
import dspy
import pandas as pd
from sklearn.model_selection import train_test_split
# Cloud Storage URI of the evaluation dataset (JSON Lines). Its rows are later
# read through the 'question', 'tool_call_response' and 'answer' columns.
INPUT_DATA_FILE_URI = 'gs://github-repo/prompts/prompt_optimizer/qa_tool_calls_dataset.jsonl'
# lines=True parses one JSON record per line.
# NOTE(review): reading a gs:// URI through pandas presumably requires gcsfs to be installed — confirm.
df = pd.read_json(INPUT_DATA_FILE_URI, lines=True)

In [11]:
# test_size=0.8 puts 80% of the rows in `testset` and the remaining 20% in
# `trainset`; random_state pins the shuffle so the split is reproducible.
trainset, testset = train_test_split(df, test_size=0.8, random_state=8)

In [None]:
# Preview the first rows of the training split.
trainset.head()

### Define Evaluation Metric/Establish Baseline







In [None]:
def evaluate_stage1(stage_1_prompt_fn, dataset):
    """Score a stage-1 prompt by exact match of the tool output.

    For every row, the prompt built from the row's 'question' is run through
    ``stage1_inference`` and the result is compared by ``==`` with the row's
    'tool_call_response'. Each prediction and expected value is printed.

    Args:
        stage_1_prompt_fn: Callable mapping a question to a stage-1 prompt.
        dataset: Object exposing ``iterrows()`` whose rows have 'question'
            and 'tool_call_response' entries (a pandas DataFrame here).

    Returns:
        The fraction of rows whose tool output matched, or 0.0 for an empty
        dataset.
    """
    scores = []
    for _, row in dataset.iterrows():
        pred = stage1_inference(stage_1_prompt_fn(row['question']))
        expected = row['tool_call_response']
        scores.append(pred == expected)
        print(pred)
        print(expected)
    if not scores:
        # Nothing was evaluated: return 0.0 rather than dividing by zero.
        return 0.0
    return sum(scores) / len(scores)

# Baseline: stage-1 exact-match score of the seed prompt on the test split.
evaluate_stage1(stage1_prompt, testset)

In [15]:
# Set the language model DSPy uses by default; ai_accuracy below relies on it
# through dspy.Predict.
dspy.configure(lm=dspy.LM('vertex_ai/gemini-2.0-flash-001'))
def ai_accuracy(answer, pred, trace=None):
    """Ask an LLM judge whether the prediction covers the ground-truth answer.

    Args:
        answer: The ground-truth answer.
        pred: The predicted answer to be judged.
        trace: Unused by this function.

    Returns:
        The ``match`` field of the judge's prediction.
    """
    judge = dspy.Predict(
        "question: str, ground_truth_answer: str, predicted_answer: str -> match: bool"
    )
    verdict = judge(
        question="Does the predicted answer contain the information in the ground truth answer? It is ok if the predicted answer is a superset and contains more information than the ground truth answer.",
        ground_truth_answer=answer,
        predicted_answer=pred,
    )
    return verdict.match

In [None]:
def evaluate_multistage(stage_1_prompt_fn, dataset):
    """Score the full two-stage pipeline with the LLM-judged accuracy metric.

    For every row, the pipeline is run on the row's 'question' and the final
    answer is judged against the row's 'answer' by ``ai_accuracy``. Each
    prediction and expected answer is printed.

    Args:
        stage_1_prompt_fn: Callable mapping a question to a stage-1 prompt.
        dataset: Object exposing ``iterrows()`` whose rows have 'question'
            and 'answer' entries (a pandas DataFrame here).

    Returns:
        The mean of the per-row scores, or 0.0 for an empty dataset.
    """
    scores = []
    for _, row in dataset.iterrows():
        question = row['question']
        pred = multi_stage_inference(stage_1_prompt_fn(question), question)
        scores.append(ai_accuracy(row['answer'], pred))
        print(pred)
        print(row['answer'])
    if not scores:
        # Nothing was evaluated: return 0.0 rather than dividing by zero.
        return 0.0
    return sum(scores) / len(scores)

# Baseline: end-to-end LLM-judged score of the seed prompt on the test split.
evaluate_multistage(stage1_prompt, testset)

### Evaluate Optimized

Refer to [the official VAPO tool calling notebook](https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/prompts/prompt_optimizer/vertex_ai_prompt_optimizer_sdk_tool_calling.ipynb) for the process to arrive at the optimized prompt.

In [None]:
def stage1_prompt_optimized(question):
    """Build the optimized stage-1 prompt for tool selection.

    Args:
        question: The user question interpolated into the prompt.

    Returns:
        The prompt text. The template's leading newline and four-space
        indentation are part of the returned string.
    """
    prompt = f"""
    To provide the most accurate response to the given question, determine and employ the most suitable tools.
    Question: {question}

    """
    return prompt
# Stage-1 exact-match score of the optimized prompt on the test split.
evaluate_stage1(stage1_prompt_optimized, testset)

In [None]:
# End-to-end LLM-judged score of the optimized prompt on the test split.
evaluate_multistage(stage1_prompt_optimized, testset)