In [None]:
from google.colab import userdata

# The Vertex AI key is stored in Colab's secret store under "VERTEX_API_KEY",
# so the credential itself never appears in the notebook source.
GEMINI_VERTEX_API_KEY = userdata.get("VERTEX_API_KEY")

In [None]:
# Notebook-wide switch: when True, the safe_invoke helpers print every parsed
# LLM response.
debug_enabled = True

In [None]:
import base64
import mimetypes
import logging

def image_to_base64(img_path):
    """Read the file at ``img_path`` and return its bytes Base64-encoded as a UTF-8 str."""
    with open(img_path, "rb") as handle:
        raw_bytes = handle.read()
    return base64.b64encode(raw_bytes).decode("utf-8")

def get_image_data_url(image_path):
    """Encode a local image file as a ``data:<mime>;base64,<payload>`` URL.

    The MIME type is guessed from the file extension via ``mimetypes``; when no
    guess is possible, ``image/png`` is used.
    """
    guessed_type = mimetypes.guess_type(image_path)[0]
    mime_type = guessed_type if guessed_type is not None else "image/png"
    encoded = image_to_base64(image_path)
    return f"data:{mime_type};base64,{encoded}"


In [None]:
import sys
import logging

logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Remember the real console streams the first time this cell runs. On a re-run
# sys.stdout/sys.stderr are already LoggerWriter instances; giving one of those
# to the console handler would make every log record re-enter the logger forever.
if "_ORIG_STDOUT" not in globals():
    _ORIG_STDOUT, _ORIG_STDERR = sys.stdout, sys.stderr

# Drop handlers from a previous run (so records are not duplicated) and close
# them so the old output.log file handle is released.
for h in logger.handlers[:]:
    logger.removeHandler(h)
    h.close()

file_handler = logging.FileHandler("output.log", encoding="utf-8", mode="w")
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(logging.Formatter("%(asctime)s - %(message)s"))

console_handler = logging.StreamHandler(_ORIG_STDOUT)
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter("%(message)s"))

logger.addHandler(file_handler)
logger.addHandler(console_handler)


class LoggerWriter:
    """Minimal file-like object that forwards written text to a logging call.

    ``print`` issues several ``write`` calls per invocation (each argument, the
    separators and the line ending), so text is buffered until a newline arrives
    and then emitted as ONE log record. Surrounding whitespace is stripped and
    blank output is skipped.
    """

    def __init__(self, level_func):
        self.level_func = level_func
        self._pending = ""

    def write(self, message):
        self._pending += message
        if "\n" in self._pending:
            # Emit everything up to the last newline; keep the unfinished tail.
            complete, _, self._pending = self._pending.rpartition("\n")
            self._emit(complete)
        return len(message)

    def flush(self):
        # Emit text that was written without a trailing newline.
        pending, self._pending = self._pending, ""
        self._emit(pending)

    def isatty(self):
        # Some libraries probe this before writing terminal control codes.
        return False

    def _emit(self, text):
        text = text.strip()
        if text:
            self.level_func(text)


sys.stdout = LoggerWriter(logger.info)
sys.stderr = LoggerWriter(logger.error)


In [None]:
import os
image_folder_path = "./receipts"
# Lower-case file extensions treated as receipt images.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp"}

# os.listdir returns entries in arbitrary order; sorting makes the order of the
# per-receipt results reproducible between runs.
image_data_urls = []
for image_name in sorted(os.listdir(image_folder_path)):
    full_path = os.path.join(image_folder_path, image_name)
    # Skip sub-directories that merely have an image-like name.
    if os.path.isfile(full_path) and os.path.splitext(image_name)[1].lower() in IMAGE_EXTENSIONS:
        image_data_urls.append(get_image_data_url(full_path))
print(f"Found {len(image_data_urls)} images for processing.")

In [None]:
from langchain_google_genai import ChatGoogleGenerativeAI

# Gemini 2.5 Flash served through Vertex AI. Temperature 0 keeps sampling
# randomness to a minimum; the key is the Colab secret read in the first cell.
llm = ChatGoogleGenerativeAI(
    vertexai=True,
    model="gemini-2.5-flash",
    temperature=0,
    api_key=GEMINI_VERTEX_API_KEY,
)

In [None]:
from dataclasses import fields
import json
from typing import List, Optional
from langchain_core.messages import BaseMessage
import re


def _response_text(response) -> str:
    """Return the stripped text of a chat-model reply.

    LangChain message ``content`` may be a plain string or a list of content
    parts (strings, or dicts such as ``{"type": "text", "text": ...}``); for a
    list, only the text parts are concatenated.
    """
    content = response.content
    if isinstance(content, list):
        pieces = []
        for part in content:
            if isinstance(part, str):
                pieces.append(part)
            elif isinstance(part, dict) and part.get("type") == "text":
                pieces.append(part.get("text", ""))
        content = "".join(pieces)
    return content.strip()


def _parse_fields(text: str, cls) -> dict:
    """Parse a JSON object from ``text`` and keep only the fields of dataclass ``cls``.

    Raises:
        json.JSONDecodeError: ``text`` is not valid JSON.
        ValueError: the JSON is not an object, or one of ``cls``'s fields is missing.
    """
    # Models often wrap the JSON in a Markdown code fence (```json ... ```).
    text = re.sub(r"^```(?:json)?\s*|\s*```$", "", text)
    data = json.loads(text)
    if not isinstance(data, dict):
        raise ValueError(f"Expected a JSON object, got {type(data).__name__}")
    field_names = {f.name for f in fields(cls)}
    filtered = {k: v for k, v in data.items() if k in field_names}
    missing = field_names - filtered.keys()
    if missing:
        raise ValueError(f"Missing fields: {missing}")
    return filtered


def _invoke_with_retries(model, cls, payload, config, stop, retries, debug, label) -> dict:
    """Shared retry loop behind safe_invoke and safe_invoke_v2.

    Calls ``model.invoke(payload, ...)`` up to ``retries`` times and returns the
    first reply that parses into ``cls``'s fields. Failures are printed rather
    than raised (best effort); if every attempt fails, each field maps to None.
    """
    # Only consult the notebook-wide flag when the caller did not decide.
    show_debug = debug_enabled if debug is None else debug
    for attempt in range(1, retries + 1):
        text = None
        try:
            response = model.invoke(payload, config=config, stop=stop)
            text = _response_text(response)
            parsed = _parse_fields(text, cls)
        except json.JSONDecodeError as e:
            print(f"[{label}:{attempt}] JSON decode error: {e}\nRaw content:\n{text}")
            continue
        except Exception as e:
            print(f"[{label}:{attempt}] Error during LLM invocation: {e}")
            continue
        if show_debug:
            print(f"[{label}] response:\n{json.dumps(parsed, indent=2, ensure_ascii=False)}")
        return parsed
    return {f.name: None for f in fields(cls)}


def safe_invoke(
    model,
    cls,
    input: str,
    config: Optional[dict] = None,
    stop: Optional[list] = None,
    retries: int = 3,
    debug: Optional[bool] = None,
) -> dict:
    """Invoke ``model`` with a text prompt and parse its JSON reply into ``cls``'s fields.

    Args:
        model: Object with ``invoke(input, config=..., stop=...)`` returning a message.
        cls: Dataclass whose field names are the required JSON keys.
        input: Prompt text.
        config: Forwarded to ``model.invoke``.
        stop: Forwarded to ``model.invoke``.
        retries: Maximum number of attempts.
        debug: Print the parsed reply; None means use the global ``debug_enabled``.

    Returns:
        Dict with exactly ``cls``'s fields (extra keys dropped). If every attempt
        fails, all values are None.
    """
    return _invoke_with_retries(model, cls, input, config, stop, retries, debug, "safe_invoke")


def safe_invoke_v2(
    model,
    cls,
    messages: "List[BaseMessage]",
    config: Optional[dict] = None,
    stop: Optional[list] = None,
    retries: int = 3,
    debug: Optional[bool] = None,
) -> dict:
    """Like ``safe_invoke`` but takes a list of chat messages (e.g. multimodal content).

    Raises:
        TypeError: ``messages`` is not a list.

    Returns:
        Dict with exactly ``cls``'s fields; all None if every attempt fails.
    """
    if not isinstance(messages, list):
        raise TypeError("safe_invoke_v2 expects `messages` to be a list of BaseMessage")
    return _invoke_with_retries(model, cls, messages, config, stop, retries, debug, "safe_invoke_mm_v2")

def _schema_type_name(tp) -> str:
    """Readable name of a dataclass field type, e.g. ``float`` or ``List[dict]``."""
    if isinstance(tp, str):  # postponed / string annotations
        return tp
    if isinstance(tp, type):
        return tp.__name__
    # typing generics such as List[dict] render as "typing.List[dict]".
    return str(tp).replace("typing.", "")


def describe_schema(cls):
    """Build the prompt section describing the JSON object the model must return.

    Lists every dataclass field of ``cls`` (except ``template_*`` fields) as
    ``- name (type): description``, using each field's ``metadata["description"]``.
    If ``cls`` has an ``output_format`` attribute, it is appended as an example.
    """
    lines = ["Your response **must** be a JSON object with the following format:"]
    for f in fields(cls):
        if f.name.startswith("template_"):
            continue
        desc = f.metadata.get("description", "")
        lines.append(f"- {f.name} ({_schema_type_name(f.type)}): {desc}")
    if hasattr(cls, "output_format"):
        lines.append(f"Here is an example of the expected output format:{cls.output_format}")
    return "\n".join(lines)

Define tool functions

In [None]:
import ast

# Builtins visible to LLM-generated analysis code. This is NOT a real sandbox,
# but together with the AST check below it removes the obvious escape hatches
# (imports, open(), eval/exec, dunder attribute access).
_SAFE_BUILTINS = {
    "abs": abs, "all": all, "any": any, "bool": bool, "dict": dict,
    "enumerate": enumerate, "filter": filter, "float": float, "int": int,
    "isinstance": isinstance, "len": len, "list": list, "map": map,
    "max": max, "min": min, "range": range, "round": round, "set": set,
    "sorted": sorted, "str": str, "sum": sum, "tuple": tuple, "zip": zip,
    "Exception": Exception, "KeyError": KeyError, "TypeError": TypeError,
    "ValueError": ValueError,
}


def _check_generated_code(tree):
    """Reject import statements and dunder names/attributes in ``tree``."""
    for node in ast.walk(tree):
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            raise ValueError("Imports are not allowed in generated code")
        if isinstance(node, ast.Attribute) and node.attr.startswith("__"):
            raise ValueError(f"Access to dunder attribute '{node.attr}' is not allowed")
        if isinstance(node, ast.Name) and node.id.startswith("__"):
            raise ValueError(f"Use of dunder name '{node.id}' is not allowed")


def safe_exec(code: str, context: dict):
    """Run LLM-generated code that defines ``compute(receipt)`` and return ``compute(context)``.

    The code is untrusted: it is parsed and checked first, then executed in a
    fresh namespace whose builtins are limited to ``_SAFE_BUILTINS``.

    Raises:
        ValueError: ``code`` is empty, uses a forbidden construct, or does not
            define a callable ``compute``.
        SyntaxError: ``code`` is not valid Python.
        Exception: anything raised by ``compute`` itself propagates.
    """
    if not code:
        raise ValueError("No analysis code to execute")
    tree = ast.parse(code)
    _check_generated_code(tree)
    # A single dict serves as globals AND locals: with separate dicts, helper
    # functions defined next to compute() would land in the locals dict and be
    # invisible to compute() (NameError when it calls them).
    namespace = {"__builtins__": _SAFE_BUILTINS}
    exec(compile(tree, "<analysis_code>", "exec"), namespace)
    compute = namespace.get("compute")
    if not callable(compute):
        raise ValueError("Generated code does not define a compute(receipt) function")
    result = compute(context)
    # default=str: a valid but non-JSON result (e.g. Decimal) must not crash the log line.
    print(f"[safe_exec]---------------\n{context}\nresult:\n{json.dumps(result, indent=4, default=str)}\n----------------")
    return result


Define Agent State Schema

In [None]:
from typing import List, Optional, TypedDict


class AgentState(TypedDict):
    """State of the top-level agent graph (plan -> analyze -> end)."""

    next_action: str  # "execute" or "interrupt", chosen by the plan node
    plan_desc: str  # one-sentence plan produced by the plan node
    user_query: str
    images_data: List[str]  # data URLs: "data:{mime_type};base64,{encoded_string}"
    receipt_results: List[Optional[float]]  # per-receipt compute() results; None where execution failed
    answer: Optional[float]  # sum of the per-receipt results


class ImageState(TypedDict):
    """State of the per-image subgraph (identify -> execute)."""

    plan_desc: str
    user_query: str
    analysis_code: str  # source defining compute(receipt), generated once by the analyze node
    image_data: str  # "data:{mime_type};base64,{encoded_string}"
    identification: dict  # fields of IdentifyOutput extracted from the image
    sub_result: Optional[float]  # compute(identification), or None if execution failed


Define LLM response formats

In [None]:
from dataclasses import dataclass, field

@dataclass
class PlanningOutput:
    """JSON schema expected from the plan node's LLM call."""

    thought: str = field(metadata={"description": "One sentence thought process."})
    next_action: str = field(metadata={"description": "The next action phase to take. Must be one of these words: execute, interrupt."})
    plan_desc: str = field(metadata={"description": "A one sentence description about how to complete the task. If next_action is 'interrupt', leave 'None'."})
    # Example shown to the model (a class attribute, not a dataclass field). It
    # must be valid JSON apart from the <placeholders>, or the model may copy
    # the defect into its reply.
    output_format = """
{
  "thought": <string>,
  "next_action": <"execute" or "interrupt">,
  "plan_desc": <string>
}
"""

@dataclass
class IdentifyOutput:
    """JSON schema expected from the identify node's LLM call for one receipt."""

    items: List[dict] = field(metadata={"description": "List of items with name, quantity, line_subtotal, discount_amount."})
    subtotal: float = field(metadata={"description": "The subtotal amount."})
    rounding: float = field(metadata={"description": "The rounding adjustment amount, with the sign shown on the receipt."})
    # Example shown to the model (a class attribute, not a dataclass field). It
    # must be valid JSON apart from the <placeholders> and the "..." line; a
    # trailing comma here invites the model to emit JSON that json.loads rejects.
    output_format = """
{
  "items": [
    {
      "name": <string>,
      "quantity": <int>,
      "line_subtotal": <float>,
      "discount_amount": <negative float, or 0 if none>
    },
    ...
  ],
  "subtotal": <float>,
  "rounding": <float>
}
"""

@dataclass
class AnalyzeOutput:
    """JSON schema expected from the analyze node's code-generation call."""

    thought: str = field(metadata={"description": "One sentence thought process."})
    code: str = field(
        metadata={"description": "Python code string defining def compute(receipt): answering the query."}
    )
    # Example shown to the model (a class attribute, not a dataclass field); it
    # must stay valid JSON apart from the <placeholders>.
    output_format = """
{
  "thought": <string>,
  "code": <string>
}
"""



Define Graph nodes

1. Subgraph nodes for image processing

In [None]:
from langgraph.types import Command
from langchain_core.messages import HumanMessage

def identify_node(state: ImageState):
    """Extract bookkeeping fields from one receipt image with a multimodal LLM call.

    Sends the extraction prompt (with the IdentifyOutput schema) plus the
    image data URL from ``state["image_data"]``, and stores the parsed reply
    (all fields None if every attempt failed) under ``identification``.
    """
    print("[Identify] Identifying image content...")
    prompt_template = """
You are the identify module of an agent that processes receipt images and handle questions.
[Task]
Given an image of a retail receipt, extract ONLY the information required for bookkeeping.
[Extraction Rules]
- The price on the right of item name is `line_subtotal`, it is the amount **BEFORE** discount for an item. You MUST NOT add it with discount amount mistakenly
- Discounts must be recorded as negative amounts.
- If no discount for one item, set discount_amount to 0.
- If quantity is not explicitly shown, set quantity = 1.
- Ignore loyalty points, card numbers, device numbers, cashier info, store address, timestamps.
- Amounts must be numeric values without currency symbols.
[Warning]
{schema}
Do NOT include explanations, comments, or any text outside the JSON.
Do NOT infer or guess missing information.
If a value is not explicitly shown on the receipt, set it to null.
"""
    text_part = {
        "type": "text",
        "text": prompt_template.format(schema=describe_schema(IdentifyOutput)),
    }
    image_part = {
        "type": "image_url",
        "image_url": {"url": state["image_data"]},
    }
    chat = [HumanMessage(content=[text_part, image_part])]
    return {"identification": safe_invoke_v2(llm, IdentifyOutput, messages=chat)}

def execute_node(state: ImageState):
    """Run the shared analysis code against this receipt's extracted data.

    Returns ``{"sub_result": value}`` where value is ``compute(identification)``
    converted to float, or None when the code fails to run or its result is
    not a number. Failures are printed, not raised, so one bad receipt does not
    abort the whole batch.
    """
    print("[Execute] Extracting data from receipt...")
    identification = state["identification"]
    analysis_code = state.get("analysis_code", "")
    try:
        exec_result = safe_exec(analysis_code, identification)
    except Exception as e:
        print(f"[Execute] Error executing code: {e}")
        return {"sub_result": None}

    # Downstream code sums the results as floats; a non-numeric value would
    # crash that sum, so normalise it here.
    try:
        sub_result = float(exec_result)
    except (TypeError, ValueError):
        print(f"[Execute] Ignoring non-numeric result: {exec_result!r}")
        sub_result = None
    return {"sub_result": sub_result}

def validate_node(state: ImageState):
    """Placeholder validation step for the per-receipt subgraph.

    Intended to check the extraction (e.g. that item totals match the subtotal)
    and request a retry via ``Command(goto=...)``. No check is implemented yet,
    so the retry target is always empty and ``sub_result`` is passed through.

    NOTE(review): the subgraph registers no node named "analyze", so that retry
    target would not resolve there. Confirm the intended target before enabling it.
    """
    print("[Validate] Validating identification and analysis...")
    retry_target = ""  # TODO: set to "identify" or "analyze" once a real check exists
    if retry_target in ("identify", "analyze"):
        return Command(goto=retry_target)
    return {"sub_result": state["sub_result"]}


In [None]:
from langgraph.graph import StateGraph, END

# Per-receipt subgraph: identify the receipt contents, then run the shared
# analysis code on them. "validate" is registered but has no incoming edge,
# so it is never reached in the current wiring.
receipt_builder = StateGraph(ImageState)
receipt_builder.add_node("identify", identify_node)
receipt_builder.add_node("execute", execute_node)
receipt_builder.add_node("validate", validate_node)

receipt_builder.set_entry_point("identify")
receipt_builder.add_edge("identify", "execute")
receipt_builder.add_edge("execute", END)

receipt_processing_graph = receipt_builder.compile()

2. Agent graph nodes

In [None]:
def plan_node(state: AgentState):
    """Ask the LLM whether the user's query is answerable and how to answer it.

    Returns ``next_action`` (stripped and lower-cased, e.g. "execute" or
    "interrupt"; "" if the LLM call failed) and ``plan_desc``.
    """
    print("[Plan] deciding next action and plan description")
    prompt = f"""
You are the plan module of an agent that processes receipt images and handles questions.
[Query from User]
{state["user_query"]}
[Ability Boundary]
You can only handle these queries:
1. Total money spent for the receipts.
2. Amount of money should have been spent without the discount.
[Task]
Decide the next action and a brief plan description to handle user's query.
If the query is irrelevant to the ability boundary, set next_action to "interrupt" and plan_desc to "None".
[Warning]
{describe_schema(PlanningOutput)}
"""
    response = safe_invoke(llm, PlanningOutput, prompt)
    # Routing compares against the exact word "interrupt", so normalise case
    # and whitespace; a failed call (None) becomes "".
    next_action = str(response["next_action"] or "").strip().lower()
    return {
        "next_action": next_action,
        "plan_desc": response["plan_desc"],
    }
    
def _build_analysis_prompt(user_query, plan_desc):
    """Prompt asking the LLM for a ``compute(receipt)`` function that answers ``user_query``."""
    return f"""
You are the analysis module of an agent that processes receipt images and handles questions.
[Task]
Generate a valid Python code snippet that defines a function `compute(receipt)` returning a numeric result.
The function will receive the receipt data as a dictionary argument.

[User Query]
{user_query}

[Basic Plan]
{plan_desc}

[Receipt Data Structure]
The `receipt` argument is a dictionary with the following fields:
- `items` (list[dict]): A list of items purchased. Each item dict has:
    - `name` (str): Name of the item.
    - `quantity` (int): Quantity purchased.
    - `line_subtotal` (float): The price of the line item **BEFORE** any discount.
    - `discount_amount` (float): The discount applied to this item (negative value).
- `subtotal` (float): The total sum before rounding.
- `rounding` (float): The rounding adjustment.

[Requirements]
{describe_schema(AnalyzeOutput)}
- The code must define exactly one function `compute(receipt)` taking `receipt` as input.
- The function must return a single numeric value (int or float).
- Use standard Python operations (sum, list comprehension, basic math).
- No imports allowed.
- Do not include comments, docstrings, or explanations.

[Examples]
User query: "How much would I have had to pay without the discount?"
def compute(receipt):
    return sum(item.get('line_subtotal', 0) for item in receipt.get('items', []))

User query: "How much money did I spend in total for these bills?"
def compute(receipt):
    return receipt.get('subtotal', 0) + receipt.get('rounding', 0)
"""


def _as_number(value):
    """Convert a per-receipt result to float, or None when missing or non-numeric."""
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


async def analyze_node(state: AgentState):
    """Generate analysis code once, run it on every receipt in parallel, and sum.

    Returns ``receipt_results`` (the ``sub_result`` of every subgraph output
    that has one) and ``answer`` (the sum of the numeric results rounded to 2
    decimals, or None when no analysis code could be generated).
    """
    print("[Analyze] Generating analysis code...")
    user_query = state.get("user_query", "")
    plan_desc = state.get("plan_desc", "")

    # The code is generated ONCE and shared by all receipts.
    response = safe_invoke(llm, AnalyzeOutput, _build_analysis_prompt(user_query, plan_desc))
    analysis_code = response["code"]
    print(f"[Analyze] Generated Analysis Code:\n{analysis_code}")
    if not analysis_code:
        # Every receipt would fail without code: skip the per-image LLM calls
        # and report "no answer" rather than a misleading total of 0.
        print("[Analyze] No analysis code was generated; skipping receipt processing.")
        return {"receipt_results": [], "answer": None}

    print("[Analyze] Processing images in parallel...")
    images = state.get("images_data", [])
    inputs = [
        {
            "image_data": img,
            "user_query": user_query,
            "plan_desc": plan_desc,
            "analysis_code": analysis_code,
        }
        for img in images
    ]
    # .abatch runs the subgraph for each input concurrently.
    results = await receipt_processing_graph.abatch(inputs)
    receipt_data = [r["sub_result"] for r in results if "sub_result" in r]
    print(f"  Received {len(receipt_data)} processed results.")

    numbers = [_as_number(value) for value in receipt_data]
    failed = sum(n is None for n in numbers)
    if failed:
        print(f"  Warning: {failed} receipt(s) produced no numeric result and count as 0.")
    # Rounding removes float noise such as 1974.3000000000002 from the sum.
    total_sum = round(sum(n for n in numbers if n is not None), 2)
    print(f"  Final Integration: Total Sum = {total_sum}")
    return {"receipt_results": receipt_data, "answer": total_sum}
  
def end_node(state):
    """Terminal node: report whether the query was rejected or processed."""
    outcome = "The query is invalid." if state["next_action"] == "interrupt" else "Completed successfully."
    print(outcome)
    return {}


In [None]:
agent_builder = StateGraph(AgentState)
agent_builder.add_node("plan", plan_node)
agent_builder.add_node("analyze", analyze_node)
agent_builder.add_node("end", end_node)
agent_builder.set_entry_point("plan")


def routing_plan(state):
    """Route after planning: rejected queries go to "end", everything else to "analyze".

    The comparison ignores case and surrounding whitespace, so an LLM reply such
    as "Interrupt" is still treated as a rejection.
    """
    next_action = str(state.get("next_action") or "").strip().lower()
    return "end" if next_action == "interrupt" else "analyze"


# The explicit path map documents the only two possible branches.
agent_builder.add_conditional_edges("plan", routing_plan, {"analyze": "analyze", "end": "end"})
agent_builder.add_edge("analyze", "end")
agent_builder.add_edge("end", END)

agent_graph = agent_builder.compile()

In [None]:
def test_query(answer, ground_truth_costs):
    # Convert string to float if necessary
    if isinstance(answer, str):
        answer = float(answer)

    # Calculate the ground truth sum once for clarity
    expected_total = sum(ground_truth_costs)

    # Check if the answer is within +/- $2 of the expected total
    assert abs(answer - expected_total) <= 2

Run the following code block to evaluate query 1:
> How much money did I spend in total for these bills?

In [None]:
# Query 1: total amount actually paid across all receipts.
# The top-level `await` relies on IPython/Colab's top-level await support.
user_input = {
    "user_query": "How much money did I spend in total for these bills?",
    "images_data": image_data_urls
}
# max_concurrency caps how many tasks the graph run executes in parallel.
result_1 = await agent_graph.ainvoke(user_input, config={"max_concurrency": 3})
query_1_costs = [394.7, 316.1, 140.8, 514.0, 102.3, 190.8, 315.6] # do not modify this
print(result_1['receipt_results'])
query1_answer = result_1['answer']
# Fails if the answer is more than the test_query tolerance away from sum(query_1_costs).
test_query(query1_answer, query_1_costs)

Run the following code block to evaluate query 2:
> How much would I have had to pay without the discount?

In [None]:
# Query 2: amount that would have been paid without any discounts.
# The top-level `await` relies on IPython/Colab's top-level await support.
user_input = {
    "user_query": "How much would I have had to pay without the discount?",
    "images_data": image_data_urls
}
# max_concurrency caps how many tasks the graph run executes in parallel.
result_2 = await agent_graph.ainvoke(user_input, config={"max_concurrency": 3})
query_2_costs = [480.20, 392.20, 160.10, 590.80, 107.70, 221.20, 396.00] # do not modify this
print(result_2['receipt_results'])
query2_answer = result_2['answer']
# Fails if the answer is more than the test_query tolerance away from sum(query_2_costs).
test_query(query2_answer, query_2_costs)