In [54]:
import json
import pandas as pd
from PIL import Image
from pathlib import Path
from gpa.datasets.attribution import PriceAttributionDataset
from tqdm import tqdm
from google import genai
from google.genai import types

In [55]:
# Split directory of the dataset; the raw CSV annotations are read from <split>/raw.
dataset_dir = Path("../data/price-graphs-ii/test")
raw_dir = dataset_dir / "raw"
dataset = PriceAttributionDataset(root=dataset_dir)

products_df = pd.read_csv(raw_dir / "product_boxes.csv", dtype={"ml_label_name": str})
price_df = pd.read_csv(raw_dir / "price_boxes.csv")

# Give every price box an image_id: the last path component of the
# image_bucket_path recorded for its attribution set in the product table.
image_lookup = products_df[["attributionset_id", "image_bucket_path"]]
joined = (
    price_df.merge(image_lookup, on="attributionset_id")
    .assign(image_id=lambda df: df["image_bucket_path"].str.split("/").str[-1])
    .drop(columns=["image_bucket_path"])
)
# The merge emits one row per matching product row, so collapse the repeats.
price_df = joined.drop_duplicates()

gemini = genai.Client(vertexai=True, project="dai-ultra", location="us-central1")

In [56]:
# Generation settings passed to gemini.models.generate_content by get_text_from_crop.
generation_settings = dict(
    temperature=1,
    top_p=0.95,
    max_output_tokens=8192,
    response_modalities=["TEXT"],
    # Ask for a JSON reply; the caller feeds the response text to json.loads.
    response_mime_type="application/json",
)
generate_content_config = types.GenerateContentConfig(**generation_settings)

In [57]:
def get_text_from_crop(crop: Image.Image, model: str):
    """Send an image crop to Gemini and return the parsed JSON reply.

    The crop is JPEG-encoded in memory and submitted together with the prompt
    stored in ``../data/text-matching/product_price_tag_extraction.txt``,
    using the module-level ``gemini`` client and ``generate_content_config``.

    Args:
        crop: Image region to submit.
        model: Name of the Gemini model to query.

    Returns:
        The object produced by ``json.loads`` on the model's reply.

    Raises:
        ValueError: If the response carries no text. Invalid JSON raises
            ``json.JSONDecodeError``, which is a ``ValueError`` subclass.
    """
    import io

    # Use a context manager so the prompt file handle is closed right away.
    with open(
        "../data/text-matching/product_price_tag_extraction.txt", "r"
    ) as prompt_file:
        prompt = prompt_file.read()

    # The JPEG encoder rejects modes such as RGBA or P; convert only those so
    # crops that could already be saved are encoded exactly as before.
    if crop.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
        crop = crop.convert("RGB")
    buffer = io.BytesIO()
    crop.save(buffer, format="JPEG")
    image_bytes = buffer.getvalue()

    raw = gemini.models.generate_content(
        model=model,
        config=generate_content_config,
        contents=[
            types.Content(
                role="user",
                parts=[
                    types.Part.from_text(text=prompt),
                    types.Part.from_bytes(
                        data=image_bytes,
                        mime_type="image/jpeg",
                    ),
                ],
            )
        ],
    )

    # Fail with an explicit message instead of an AttributeError on None.
    reply = raw.text
    if not reply:
        raise ValueError(f"Model {model!r} returned no text for this crop")

    # Tolerate a Markdown code fence around the JSON, with or without the
    # "json" language tag.
    cleaned = (
        reply.strip()
        .removeprefix("```json")
        .removeprefix("```")
        .removesuffix("```")
        .strip()
    )
    return json.loads(cleaned)

In [61]:
cache_path = Path("../data/text-matching/price_tag_product_metadata.json")
# Start with an empty cache on the first run instead of failing on a missing file.
if cache_path.exists():
    with open(cache_path, "r") as f:
        cache = json.load(f)
else:
    cache = {}

# Images are read from an "images" directory that is a sibling of the split
# directory. with_name swaps only the last path component, whereas a string
# replace would rewrite every "test" substring in the path.
images_dir = dataset_dir.with_name("images")

groups = price_df.groupby("image_id")
# Derive the progress-bar total from the data rather than hard-coding it.
for image_id, group in tqdm(groups, total=groups.ngroups):
    # The context manager releases each image's file handle after its crops are done.
    with Image.open(images_dir / image_id) as image:
        for row in group.itertuples(index=False):
            # Keys loaded from JSON are always strings, so look up and store
            # ids as str. NOTE(review): the dtype of price_bbox_id is not
            # visible here; this is a no-op if it is already a string.
            bbox_key = str(row.price_bbox_id)
            if bbox_key in cache:
                continue
            # Box coordinates are scaled by the image size to get pixel
            # coordinates (presumably they are stored normalised — confirm).
            coords = (
                row.min_x * image.width,
                row.min_y * image.height,
                row.max_x * image.width,
                row.max_y * image.height,
            )
            price_crop = image.crop(coords)
            cache[bbox_key] = get_text_from_crop(
                price_crop, model="gemini-2.0-flash-lite-001"
            )

100%|██████████| 325/325 [27:43<00:00,  5.12s/it]  


In [62]:
# Serialize before opening the file: mode "w" truncates immediately, so an
# error raised while json.dump is writing would destroy the existing cache.
serialized_cache = json.dumps(cache, indent=2)
with open("../data/text-matching/price_tag_product_metadata.json", "w") as f:
    f.write(serialized_cache)

In [60]:
# NOTE(review): this cell repeats the cache save from the previous cell and its
# execution count (60) predates the extraction loop (61); consider deleting it.
# Serialize before opening the file: mode "w" truncates immediately, so an
# error raised while json.dump is writing would destroy the existing cache.
serialized_cache = json.dumps(cache, indent=2)
with open("../data/text-matching/price_tag_product_metadata.json", "w") as f:
    f.write(serialized_cache)