<a href="https://colab.research.google.com/github/tracydo99/bus118/blob/main/Ad_Optimization_Agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [48]:
!pip install langgraph langchain langchain_openai pydantic pandas -qq

In [49]:
import os
import operator
import pandas as pd
from typing import TypedDict, Annotated, List
from langchain_core.messages import BaseMessage, HumanMessage, ToolMessage
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.tools import tool
from langgraph.graph import StateGraph, END
from google.colab import userdata
from pydantic import BaseModel, Field

In [50]:
# Copy the OpenAI key from the Colab secret store into the process environment.
# A failure is reported as a warning rather than aborting the notebook.
try:
    os.environ["OPENAI_API_KEY"] = userdata.get("OPENAI_API_KEY")
except Exception as err:
    print(f"WARNING: Could not load OPENAI_API_KEY from Colab Secrets. Please ensure it is set. Error: {err}")
else:
    print("OPENAI_API_KEY loaded from Colab Secrets.")

OPENAI_API_KEY loaded from Colab Secrets.


In [51]:
import pandas as pd

# Load the ad performance CSV (path is relative to the notebook's working directory)
df = pd.read_csv("ad_performance_data.csv")

# Show the first 5 rows as a table. `display` is not imported in this file;
# it is presumably the helper the notebook environment injects.
display(df.head())

Unnamed: 0,date,channel,spend,impressions,clicks,conversions
0,2025-10-01,Search,97.49,5323,176,7
1,2025-10-01,Social,99.01,14284,117,4
2,2025-10-01,Display,103.5,34737,36,0
3,2025-10-02,Search,106.65,5383,181,8
4,2025-10-02,Social,84.25,12362,120,3


In [52]:
# 1. Define the Agent State
class AgentState(TypedDict):
    """
    Shared state handed from node to node in the ad optimization graph.

    - messages: Conversation history (prompt, model replies, tool results).
    - current_data: The full ad campaign data as loaded by the fetch node.
    - next_action: Routing hint written by each node and read by
      decide_next_step ("call_tool", "FINISH", anything else means re-reason).
    - decision_log: Optimization decisions recorded by the tool-execution node.
    """
    # Annotated with operator.add: a list returned by a node is merged into the
    # existing history by concatenation instead of replacing it
    # (NOTE(review): relies on LangGraph treating the annotation as a reducer).
    messages: Annotated[List[BaseMessage], operator.add]
    current_data: pd.DataFrame
    next_action: str
    # Same operator.add annotation as `messages`, so log entries accumulate.
    decision_log: Annotated[List[dict], operator.add]

In [53]:
# Chat model used by the agent_reasoning node. temperature=0 is presumably
# chosen to keep budget recommendations as repeatable as the API allows.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)

In [57]:
# --- Data Reading and Processing ---

def read_ad_data(file_path: str = "ad_performance_data.csv") -> pd.DataFrame:
    """Read the ad performance CSV and return it with a parsed ``date`` column.

    The required columns are validated *before* any column is accessed, so a
    file without a ``date`` column is reported as a schema problem instead of
    as a bare ``KeyError('date')``.

    Args:
        file_path: Path of the CSV file to load.

    Returns:
        The loaded DataFrame with ``date`` converted to datetimes. If the file
        is missing, unreadable, or fails validation, the problem is printed and
        the result of ``create_mock_data()`` is returned instead. This is a
        deliberate best-effort fallback: read and validation errors do not
        propagate to the caller.
    """
    required_cols = ['date', 'channel', 'spend', 'impressions', 'clicks', 'conversions']
    try:
        df = pd.read_csv(file_path)
        # Check the schema first; parsing df['date'] on a file without that
        # column would otherwise mask the real problem.
        if not all(col in df.columns for col in required_cols):
            raise ValueError(f"CSV missing required columns. Found: {df.columns.tolist()}")
        df['date'] = pd.to_datetime(df['date'])
    except FileNotFoundError:
        print(f"ERROR: File not found at {file_path}. Generating mock data.")
        return create_mock_data()
    except Exception as e:
        # Broad on purpose: any read/parse failure falls back to mock data.
        print(f"An error occurred reading or processing the CSV: {e}. Generating mock data.")
        return create_mock_data()

    print(f"Successfully read data from {file_path}. Total records: {len(df)}")
    return df

def create_mock_data(seed: "int | None" = None) -> pd.DataFrame:
    """Create 30 days of synthetic ad data, used when the CSV cannot be loaded.

    One row is generated per day and channel (Search, Social, Display),
    starting at 2024-10-01. ``impressions``, ``clicks`` and ``conversions``
    are whole numbers so the mock data has the same shape as real count data.
    Search is given a higher conversion-rate range than the other channels.

    Args:
        seed: Optional seed for the random generator. Pass a value to get
            reproducible data; the default ``None`` draws fresh random data
            on every call. Note that the generator is independent of NumPy's
            global ``np.random.seed`` state.

    Returns:
        DataFrame with columns ``date``, ``channel``, ``spend``,
        ``impressions``, ``clicks`` and ``conversions`` (90 rows).
    """
    import numpy as np  # local import: numpy is only needed for this fallback

    rng = np.random.default_rng(seed)
    rows = []
    for date in pd.date_range(start="2024-10-01", periods=30):
        for channel in ['Search', 'Social', 'Display']:
            spend = rng.uniform(50, 150)
            impressions = int(rng.integers(10000, 50000))
            # Clicks and conversions are counts, so round the expected values.
            clicks = int(round(impressions * rng.uniform(0.005, 0.015)))
            # Give Search a slight edge in conversions for initial test
            conv_rate = rng.uniform(0.01, 0.05) if channel == 'Search' else rng.uniform(0.001, 0.02)
            conversions = int(round(clicks * conv_rate))
            rows.append([date, channel, spend, impressions, clicks, conversions])
    return pd.DataFrame(
        rows,
        columns=['date', 'channel', 'spend', 'impressions', 'clicks', 'conversions'],
    )

def process_data(df: pd.DataFrame) -> tuple[List[dict], float]:
    """
    Computes key performance metrics (CVR, CTR, CPA) for the most recent day.

    Only rows whose ``date`` equals the maximum date in ``df`` are used.
    For each of those rows:

    - CVR = conversions / clicks (0 when there are no clicks)
    - CTR = clicks / impressions (0 when there are no impressions)
    - CPA = spend / conversions (infinity when there are no conversions)
    - spend_percent = row spend / total spend of the day (0 when the total is 0)

    Args:
        df: Ad data with a datetime ``date`` column plus ``channel``,
            ``spend``, ``impressions``, ``clicks`` and ``conversions``.

    Returns:
        A tuple ``(records, total_spend)`` where ``records`` is a list of
        dicts with keys ``channel``, ``spend``, ``CVR``, ``CTR``, ``CPA`` and
        ``spend_percent``, and ``total_spend`` is the day's summed spend.

    Raises:
        ValueError: If ``df`` has no rows or no valid date to select.
    """
    if df.empty:
        raise ValueError("Cannot process ad data: the DataFrame has no rows.")

    latest_date = df['date'].max()
    if pd.isna(latest_date):
        raise ValueError("Cannot process ad data: the 'date' column has no valid dates.")
    latest_data = df[df['date'] == latest_date].copy()

    spend = latest_data['spend']
    impressions = latest_data['impressions']
    clicks = latest_data['clicks']
    conversions = latest_data['conversions']

    # Compute Metrics (vectorised). Division by zero yields inf/NaN here and is
    # then replaced by the documented fallback via .where().
    latest_data['CVR'] = (conversions / clicks).where(clicks > 0, 0.0)
    latest_data['CTR'] = (clicks / impressions).where(impressions > 0, 0.0)
    latest_data['CPA'] = (spend / conversions).where(conversions > 0, float('inf'))

    # Format data for LLM
    analysis_data = latest_data[['channel', 'spend', 'CVR', 'CTR', 'CPA']].copy()
    total_spend = float(spend.sum())
    if total_spend > 0:
        analysis_data['spend_percent'] = analysis_data['spend'] / total_spend
    else:
        # Avoid NaN shares in the prompt when nothing was spent that day.
        analysis_data['spend_percent'] = 0.0

    print(f"--- Data for latest date ({latest_date.date()}) processed. ---")

    return analysis_data.to_dict('records'), total_spend

In [58]:
# --- Tooling for Budget Allocation ---

# Argument schema for the propose_budget_shift tool (passed as args_schema).
# NOTE(review): the class docstring and both Field descriptions are runtime
# text -- they presumably end up in the tool schema shown to the model -- so
# treat them like prompt wording when editing.
# The limits they describe (0-20 per shift, shifts summing to 0%) are stated in
# prose only; nothing in this class validates them.
class BudgetShift(BaseModel):
    """Input for proposing a budget shift, adhering to all guardrails."""
    # One free-text entry per channel, e.g. 'Search: +15% (Highest CVR)'.
    channel_shifts: List[str] = Field(
        description="A list of proposed budget changes. Each entry must be a string in the format: 'Channel: +X% (Reason)' or 'Channel: -Y% (Reason)'. X and Y must be between 0 and 20. The total sum of all shifts must equal 0%. The *reason* is mandatory, citing CVR or CTR."
    )
    # Overall justification; copied into the decision log by execute_tools.
    rationale: str = Field(
        description="A brief, overall rationale for the budget shifts. State which channel(s) are favored and why (e.g., higher CVR/stable CPA)."
    )

def _parse_shift_percent(entry: str) -> float:
    """Return the signed percentage from one 'Channel: +X% (Reason)' entry.

    Raises:
        ValueError: If the entry does not contain 'Channel: <number>%'.
    """
    import re  # function-scope import: `re` is not imported at file level

    # Accept an optional sign (ASCII '+'/'-' or the Unicode minus U+2212) and
    # an integer or decimal magnitude followed by '%'.
    match = re.match(r"\s*[^:]+:\s*([+\-\u2212]?)\s*(\d+(?:\.\d+)?)\s*%", entry)
    if match is None:
        raise ValueError(
            f"Could not parse budget shift {entry!r}; expected 'Channel: +X% (Reason)'."
        )
    sign, magnitude = match.groups()
    value = float(magnitude)
    return -value if sign in ("-", "\u2212") else value


# The docstring below is left unchanged because it is the tool description.
# In addition to acknowledging the proposal, the tool now enforces the
# guardrails stated in BudgetShift: every shift must be within +/-20% and all
# shifts must sum to 0%. Violations raise ValueError, which execute_tools
# turns into an error ToolMessage so the model can correct its proposal.
@tool("propose_budget_shift", args_schema=BudgetShift)
def propose_budget_shift(channel_shifts: List[str], rationale: str) -> str:
    """
    Proposes new daily budget allocations across channels (Search, Social, Display)
    based on performance data to maximize conversions or CTR. Logs the decision.
    """
    max_shift_percent = 20.0
    shifts = [_parse_shift_percent(entry) for entry in channel_shifts]

    over_cap = [
        entry for entry, pct in zip(channel_shifts, shifts)
        if abs(pct) > max_shift_percent
    ]
    if over_cap:
        raise ValueError(
            f"Each budget shift must be within +/-{max_shift_percent:g}%; violating entries: {over_cap}"
        )

    total_shift = sum(shifts)
    if abs(total_shift) > 1e-9:  # tolerance for decimal percentages
        raise ValueError(
            f"Budget shifts must sum to 0%, but the proposed shifts sum to {total_shift:+g}%."
        )

    # This tool simulates the successful application of the proposed budget shift.
    return f"Budget Shift Proposed and Logged. Rationale: {rationale}"

In [60]:
# 3. Graph Nodes and Functions

def fetch_and_process_data(state: AgentState) -> dict:
    """Entry node: load the campaign data and build the analysis request.

    Reads the ad data, computes the latest-day metrics, and wraps them together
    with the optimization goal, heuristic, guardrails and task into a single
    HumanMessage. The incoming ``state`` is not read.

    Returns:
        State update with the new message, the full DataFrame under
        ``current_data``, ``next_action`` set to "propose_shift", and an empty
        ``decision_log`` contribution.
    """
    print("--- FETCHING DATA ---")

    campaign_df = read_ad_data()
    channel_metrics, daily_budget = process_data(campaign_df)

    data_section = (
        "**Ad Performance Data (Latest Day):**\n"
        f"{channel_metrics}\n\n"
        f"**Total Daily Budget:** ${daily_budget:.2f}\n\n"
    )
    goal_section = (
        "**Optimization Goal:** Reallocate the budget across the channels (Search, Social, Display) "
        "to **maximize total conversions** by prioritizing channels with the highest **CVR** or **CTR**.\n"
        "**Heuristic:** Compute each channel’s CVR and CTR. Shift **+10% to +20%** budget toward the top performer(s) while keeping a minimum floor for others.\n"
    )
    guardrail_section = (
        "**Guardrails:**\n"
        "1. Cap per-channel budget change at **±20%** of the *current* channel's budget.\n"
        "2. Never allocate **0%** to a channel for more than 2 consecutive days (ensure continuous spend).\n"
        "3. The sum of all shifts must be **0%** (i.e., balanced reallocation).\n\n"
    )
    task_section = "**Task:** Use the `propose_budget_shift` tool. Provide a list of budget shifts as percentages (e.g., ['Search: +15% (Highest CVR)', 'Social: -10% (Lowest CVR)', 'Display: -5% (Medium CTR)']) and your overall rationale."

    request = HumanMessage(
        content=data_section + goal_section + guardrail_section + task_section
    )

    # Keep the full data in state for later nodes; the log starts out empty.
    return {
        "messages": [request],
        "current_data": campaign_df,
        "next_action": "propose_shift",
        "decision_log": [],
    }

def agent_reasoning(state: AgentState) -> dict:
    """Reasoning node: ask the LLM for a budget reallocation.

    The model is bound to the ``propose_budget_shift`` tool and invoked with a
    fixed system prompt followed by the message history from ``state``.

    Returns:
        State update containing the model response, with ``next_action`` set to
        "call_tool" when the response carries tool calls and "FINISH" otherwise.
    """
    print("--- AGENT REASONING ---")

    system_prompt = (
        "You are an expert Ad Budget Optimization Agent. Your goal is to analyze the provided performance data and decide the optimal **budget reallocation** using the `propose_budget_shift` tool. "
        "You **must** use the `propose_budget_shift` tool. Ensure your proposed shifts meet all guardrails (max +/-20% per channel, sum of shifts is 0%)."
    )
    template = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("placeholder", "{messages}")
    ])
    tool_enabled_llm = llm.bind_tools([propose_budget_shift])

    response = (template | tool_enabled_llm).invoke(state)

    # Without a tool call the reply itself is treated as the final answer.
    next_action = "call_tool" if response.tool_calls else "FINISH"
    return {"messages": [response], "next_action": next_action}

def execute_tools(state: AgentState) -> dict:
    """Tool node: run every tool call from the last message and log the results.

    Each tool call in the most recent message receives exactly one ToolMessage
    carrying its ``tool_call_id``: the tool result on success, or an error
    description when the tool raises or the tool name is not recognised.
    Answering every call matters on the retry path, because the message
    history is sent back to the model.

    Returns:
        State update with the ToolMessages, one ``decision_log`` entry per
        successful ``propose_budget_shift`` call, and ``next_action`` set to
        "agent_reasoning" if any call failed, otherwise "FINISH".
    """
    print("--- EXECUTING TOOL ---")

    tool_calls = state["messages"][-1].tool_calls
    tool_results = []
    log_entries = []
    any_failed = False

    for call in tool_calls:
        tool_name = call["name"]
        tool_args = call["args"]
        tool_call_id = call["id"]

        if tool_name != "propose_budget_shift":
            # Still answer the call so it is not left without a response.
            any_failed = True
            tool_results.append(ToolMessage(
                content=f"Error executing tool {tool_name}: unknown tool.",
                tool_call_id=tool_call_id
            ))
            continue

        try:
            # Execute the tool function
            result = propose_budget_shift.invoke(tool_args)

            # --- LOGGING THE DECISION ---
            log_entry = {
                "date": state["current_data"]["date"].max().date().isoformat(),
                "action": "Budget Shift Proposal",
                "details": tool_args["channel_shifts"],
                "rationale": tool_args["rationale"],
            }
        except Exception as e:
            # Report the failure to the model instead of aborting the run.
            any_failed = True
            tool_results.append(ToolMessage(
                content=f"Error executing tool {tool_name}: {e}",
                tool_call_id=tool_call_id
            ))
        else:
            log_entries.append(log_entry)
            tool_results.append(ToolMessage(
                content=result,
                tool_call_id=tool_call_id
            ))

    # Any failure sends the errors back to reasoning; otherwise the run ends.
    next_action = "agent_reasoning" if any_failed else "FINISH"
    return {"messages": tool_results, "decision_log": log_entries, "next_action": next_action}

def decide_next_step(state: AgentState) -> str:
    """Conditional edge: map ``next_action`` in the state to a routing key.

    Returns "call_tool" or "FINISH" when the state requests exactly that, and
    "agent_reasoning" for every other value, including a missing key (e.g.
    after a tool failure message).
    """
    # The returned string must match a key in the conditional edge map.
    requested = state.get("next_action")
    if requested in ("call_tool", "FINISH"):
        return requested
    return "agent_reasoning"

In [61]:
# 4. Build the LangGraph Workflow
# Flow: fetch_data -> agent_reasoning -> (execute_tools | END), and
# execute_tools -> (agent_reasoning | END).
workflow = StateGraph(AgentState)

# Register the three node functions under the names used by the edges below
workflow.add_node("fetch_data", fetch_and_process_data)
workflow.add_node("agent_reasoning", agent_reasoning)
workflow.add_node("execute_tools", execute_tools)

# Every run starts by loading the data and building the analysis prompt
workflow.set_entry_point("fetch_data")

# Unconditional edge: once the prompt exists, always go to reasoning
workflow.add_edge("fetch_data", "agent_reasoning")

# Conditional edge from reasoning (decides if it's a tool call or the end).
# decide_next_step can also return "agent_reasoning", but agent_reasoning only
# ever sets next_action to "call_tool" or "FINISH", so that key is not mapped here.
workflow.add_conditional_edges(
    "agent_reasoning",
    decide_next_step,
    {"call_tool": "execute_tools", "FINISH": END}
)

# Conditional edge from tool execution (decides if it's finished or needs to re-reason)
workflow.add_conditional_edges(
    "execute_tools",
    decide_next_step,
    {
        "agent_reasoning": "agent_reasoning",  # Go back to reasoning (e.g., if tool failed)
        "FINISH": END                          # Finish the graph run
    }
)

# Compile the graph into the runnable `app` used by the run cell
app = workflow.compile()

In [62]:
# 5. Run the Agent
print("\n" + "="*50)
print("--- STARTING AD OPTIMIZATION AGENT RUN ---")
print("="*50)

# The agent runs autonomously until it hits the END node.
# The initial values are placeholders: the fetch_data node replaces
# current_data and next_action, and contributes the first message.
# recursion_limit is passed through to the graph run; it presumably bounds how
# many steps the reasoning/tool retry loop may take -- confirm against LangGraph docs.
final_state = app.invoke(
    {"messages": [], "current_data": pd.DataFrame(), "next_action": "start", "decision_log": []},
    config={"recursion_limit": 50})


--- STARTING AD OPTIMIZATION AGENT RUN ---
--- FETCHING DATA ---
Successfully read data from ad_performance_data.csv. Total records: 60
--- Data for latest date (2025-10-20) processed. ---
--- AGENT REASONING ---
--- EXECUTING TOOL ---


In [63]:
# 6. Print Final Result and Log
print("\n" + "="*50)
print("AGENT FINAL RECOMMENDATION & SUMMARY:")

# Content of the last message in the run: the tool result after a tool call,
# or the model's own reply if it finished without calling the tool.
final_message = final_state["messages"][-1].content
print(final_message)

print("\n--- DECISION LOG ---")
if not final_state["decision_log"]:
    print("No budget shift was successfully proposed and logged.")
else:
    for entry in final_state["decision_log"]:
        # One line per field, followed by a separator line.
        print(
            f"Date: {entry['date']}",
            f"Action: {entry['action']}",
            f"Rationale: {entry['rationale']}",
            f"Shifts: {entry['details']}",
            "-" * 20,
            sep="\n",
        )

print("="*50)


AGENT FINAL RECOMMENDATION & SUMMARY:
Budget Shift Proposed and Logged. Rationale: The budget is being shifted towards Search due to its highest CVR, which indicates better performance in converting clicks to actions. Social is being reduced due to its lower CVR, and Display is also reduced as it has not generated any conversions.

--- DECISION LOG ---
Date: 2025-10-20
Action: Budget Shift Proposal
Rationale: The budget is being shifted towards Search due to its highest CVR, which indicates better performance in converting clicks to actions. Social is being reduced due to its lower CVR, and Display is also reduced as it has not generated any conversions.
Shifts: ['Search: +15% (Highest CVR)', 'Social: -10% (Lowest CVR)', 'Display: -5% (No Conversions)']
--------------------
