In [9]:
import json
import torch
import numpy as np
import pandas as pd
from itertools import chain
from pathlib import Path
from gpa.datasets.attribution import PriceAttributionDataset
from products_client import get_product_client
from tqdm import tqdm
from gpa.datasets.attribution import DetectionGraph
from google import genai
from google.genai import types


sabre = get_product_client(
    url="https://products.qa.deliciousai.com/graphql/",
    client_id="OlDSXQZzZmr8ifr4ng4xXAc044TGD7TcgksIqXaT",
    client_secret="t6Ci92eN3eUzbAbsjmPw5VpYWbPTEfqT8Kyauh3yjhLkCNuIKFNmCMwkV4MhFOT1ZkFVDBUmtIWxzUQwt6Mvl0OzQdlVRkebINgXMabdHNmoJ8qsw415YMrq5oVQmG53",
)
gemini = genai.Client(vertexai=True, project="dai-ultra", location="us-central1")

In [8]:
generate_content_config = types.GenerateContentConfig(
    temperature=1,
    top_p=0.95,
    max_output_tokens=8192,
    response_modalities=["TEXT"],
    response_mime_type="application/json",
)

In [None]:
cache = {}

In [None]:
for upc in tqdm(
    products_df.ml_label_name.unique(), total=len(products_df.ml_label_name.unique())
):
    if upc not in cache:
        try:
            product_response = sabre.query.product(upc=upc)
            cache[upc] = product_response.name
        except:
            cache[upc] = None

with open("../data/text-matching/products.json", "w") as f:
    json.dump(cache, f, indent=2)

In [4]:
with open("../data/text-matching/products.json", "r") as f:
    cache = json.load(f)

In [42]:
dataset_dir = Path("../data/price-graphs-ii/test")
dataset = PriceAttributionDataset(root=dataset_dir)
products_df = pd.read_csv(
    dataset_dir / "raw" / "product_boxes.csv",
    index_col="attributionset_id",
    dtype={"ml_label_name": str},
)
price_df = pd.read_csv(
    dataset_dir / "raw" / "price_boxes.csv",
    index_col="price_bbox_id",
)
with open("../data/text-matching/price_tag_product_metadata.json", "r") as f:
    price_product = json.load(f)

for bbox_id, product_metadata in price_product.items():
    price_df.loc[bbox_id, "product_metadata"] = (
        product_metadata["price_product_metadata"] or "NULL"
    )

price_df = price_df.set_index("attributionset_id")

In [56]:
def get_products_dict(
    scene_graph: DetectionGraph,
    scene: pd.DataFrame,
    product_cache: dict[str, dict[str, str]] | None = None,
):
    products = []
    grouped = dict(
        pd.DataFrame(zip(scene.ml_label_name, scene_graph.product_indices.tolist()))
        .groupby(0)
        .agg(list)
        .itertuples(index=True, name=None)
    )
    for sku, indices in grouped.items():
        product = {"sku": sku}
        if product_cache is not None and product_cache[sku] is not None:
            product["name"] = product_cache[sku]
        products.append(product)
    return products, grouped


def get_prices_dict(scene_graph: DetectionGraph, scene_prices: pd.DataFrame):
    prices = []
    grouped = dict(
        pd.DataFrame(
            zip(
                scene_prices.price_type,
                scene_prices.price_contents,
                scene_prices.product_metadata,
                scene_graph.price_indices.tolist(),
            )
        )
        .groupby([0, 1, 2])
        .agg(list)
        .itertuples(index=True, name=None)
    )
    grouped_to_return = dict(
        pd.DataFrame(
            zip(
                scene_prices.price_type,
                scene_prices.price_contents,
                scene_graph.price_indices.tolist(),
            )
        )
        .groupby([0, 1])
        .agg(list)
        .itertuples(index=True, name=None)
    )
    for (price_type, price_contents, product_metadata), indices in grouped.items():
        pr = {
            "price_type": price_type,
            "contents": price_contents,
        }
        if product_metadata != "NULL":
            pr["product_metadata"] = product_metadata
        prices.append(pr)
    return prices, grouped_to_return


def call_llm(prompt: str, model: str) -> dict:
    raw = gemini.models.generate_content(
        model=model,
        contents=[types.Content(role="user", parts=[types.Part(text=prompt)])],
        config=generate_content_config,
    )
    cleaned = raw.text.strip().removeprefix("```json").removesuffix("```").strip()
    return json.loads(cleaned)


def get_llm_prediction(graph: DetectionGraph, model: str) -> torch.Tensor:
    products, product_groups = get_products_dict(
        graph, scene=products_df.loc[[graph.graph_id]], product_cache=cache
    )
    prices, prices_grouped = get_prices_dict(
        graph, scene_prices=price_df.loc[[graph.graph_id]]
    )
    prompt = (
        open("../data/text-matching/prompt.txt", "r")
        .read()
        .format(products=products, price_tags=prices)
    )
    response = call_llm(prompt, model)
    data = []
    rows = []
    cols = []
    for upc, prices in response.items():
        product_indices = product_groups[upc]
        _chain = []
        for price in prices:
            _chain.append(prices_grouped[(price["price_type"], price["contents"])])
        prices_indices = list(chain.from_iterable(_chain))
        data.append((product_indices, prices_indices))

    for sources, targets in data:
        for s in sources:
            for t in targets:
                rows.append(s)
                cols.append(t)

    return torch.tensor([rows, cols], dtype=torch.long)


def compute_precision_recall_undirected(
    pred_edges: torch.Tensor, gt_edges: torch.Tensor
) -> tuple[float, float, float]:
    def normalize_edges(edges: torch.Tensor) -> set:
        mask = (edges >= 0).all(dim=0)
        edges = edges[:, mask]
        return {tuple(sorted(e)) for e in edges.t().tolist()}

    pred_set = normalize_edges(pred_edges)
    gt_set = normalize_edges(gt_edges)
    tp = len(pred_set & gt_set)
    fp = len(pred_set - gt_set)
    fn = len(gt_set - pred_set)
    pr = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    re = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1_score = (2 * pr * re) / (pr + re) if (pr + re) > 0 else 0.0

    return pr, re, f1_score

In [61]:
all_prec = []
all_rec = []
all_f1 = []
test_dataset = dataset
for G in tqdm(test_dataset, total=len(test_dataset)):
    gt = G.gt_prod_price_edge_index
    pred = get_llm_prediction(G, model="gemini-2.0-flash-lite-001")
    # pred = get_llm_prediction(G, model="gemini-2.5-flash")
    precision, recall, f1 = compute_precision_recall_undirected(
        pred_edges=pred, gt_edges=gt
    )
    all_prec.append(precision)
    all_rec.append(recall)
    all_f1.append(f1)

mean_precision, std_precision = (
    np.array(all_prec).mean(),
    np.array(all_prec).std(),
)
mean_recall, std_recall = np.array(all_rec).mean(), np.array(all_rec).std()
mean_f1, std_f1 = np.array(all_f1).mean(), np.array(all_f1).std()

100%|██████████| 325/325 [08:36<00:00,  1.59s/it]


In [62]:
print(
    {
        "precision": {
            "mean": f"{mean_precision:.3f}",
            "std": f"{std_precision:.3f}",
        },
        "recall": {
            "mean": f"{mean_recall:.3f}",
            "std": f"{std_recall:.3f}",
        },
        "f1": {
            "mean": f"{mean_f1:.3f}",
            "std": f"{std_f1:.3f}",
        },
        "model": "gemini-2.5-flash-lite-001",
    }
)

{'precision': {'mean': '0.604', 'std': '0.392'}, 'recall': {'mean': '0.610', 'std': '0.394'}, 'f1': {'mean': '0.568', 'std': '0.363'}, 'model': 'gemini-2.5-flash-lite-001'}


In [61]:
print("Flash2.5: ", mean_precision, mean_recall, mean_f1)

Flash2.5:  0.27301191841054256 0.3991116427432217 0.3174188924424244


In [51]:
G.gt_prod_price_edge_index.T

tensor([[ 5, 50],
        [14, 50],
        [48, 50],
        [13, 50],
        [20, 50],
        [25, 50],
        [28, 50],
        [ 6, 50],
        [ 3, 50],
        [29, 50],
        [45, 50],
        [23, 50],
        [ 9, 50],
        [ 7, 50],
        [33, 50],
        [46, 50],
        [15, 50],
        [49, 50],
        [ 4, 50],
        [31, 50],
        [10, 50],
        [ 0, 50],
        [19, 50],
        [44, 50],
        [39, 50],
        [22, 50],
        [35, 50],
        [26, 50],
        [43, 50],
        [ 1, 50],
        [34, 50],
        [17, 50],
        [36, 50],
        [32, 50],
        [37, 50],
        [40, 50],
        [ 2, 50],
        [11, 50],
        [47, 50],
        [16, 50],
        [42, 50],
        [ 8, 50],
        [18, 50],
        [38, 50],
        [21, 50],
        [12, 50],
        [27, 50],
        [41, 50],
        [24, 50],
        [30, 50],
        [50,  5],
        [50, 14],
        [50, 48],
        [50, 13],
        [50, 20],
        [5

In [52]:
precision

0.0