In [1]:
import asyncio
import base64
import os
from os import PathLike
from pathlib import Path

import numpy as np
import pandas as pd
from tensorzero import AsyncTensorZeroGateway
from tqdm.asyncio import tqdm_asyncio

In [2]:
DATA_PATH = Path("data")

LABELS_PATH = DATA_PATH / "labels.csv"

CONCURRENCY = 10

VARIANT_NAME = "baseline"

In [3]:
def load_data(path: PathLike):
    df = pd.read_csv(LABELS_PATH)

    # Sanity Check: ensure every image exists
    for _, row in df.iterrows():
        img_path = path / Path(row["document"])
        assert img_path.exists(), (
            f"Image {img_path} does not exist. Ensure you've downloaded the dataset correctly."
        )

    train_df = df[df["is_train"] == 1].reset_index(drop=True)
    test_df = df[df["is_train"] == 0].reset_index(drop=True)

    return train_df, test_df


train_df, test_df = load_data(DATA_PATH)

print(f"Found {len(train_df)} train documents and {len(test_df)} test documents")

Found 200 train documents and 200 test documents


In [4]:
train_df.sample(5)

Unnamed: 0,document,label,is_train
186,dataset/train/0903.3513v1.Fuzzy_Chemical_Abstr...,cs.FL,1
75,dataset/train/0706.2748v2.A_Survey_of_Unix_Ini...,cs.OS,1
65,dataset/train/cs_0001004v1.Multiplicative_Algo...,cs.LG,1
188,dataset/train/0904.3366v1.State_complexity_of_...,cs.FL,1
66,dataset/train/cs_0002006v1.Multiplicative_Nonh...,cs.LG,1


In [5]:
test_df.sample(5)

Unnamed: 0,document,label,is_train
37,dataset/test/cs_0109084v1.The_Internet_and_Com...,cs.DB,0
98,dataset/test/cs_0204051v1.Parrondo_Strategies_...,cs.CE,0
65,dataset/test/cs_0011044v1.Scaling_Up_Inductive...,cs.LG,0
104,dataset/test/cs_0109009v1.The_Effect_of_Native...,cs.CY,0
105,dataset/test/1011.1519v1.Fuzzy_Controller_for_...,cs.SY,0


In [6]:
os.makedirs("tensorzero/object_storage", exist_ok=True)

In [7]:
t0 = await AsyncTensorZeroGateway.build_http(
    gateway_url="http://localhost:3000",
)

In [8]:
def load_document(path: PathLike):
    """Load an image and encode as a base64 string"""
    path = DATA_PATH / path
    assert path.exists()
    assert path.suffix.lower() == ".png"

    with open(path, "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")

In [9]:
async def process_document(row):
    response = await t0.inference(
        function_name="classify_document",
        input={
            "system": "Categorize this document.",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "mime_type": "image/png",
                            "data": load_document(row.document),
                        },
                    ],
                }
            ],
        },
        dryrun=not row.is_train,
        cache_options={
            "enabled": "on",
        },
    )

    correct_classification = response.output.parsed == row.label

    if row.is_train:
        await t0.feedback(
            metric_name="correct_classification",
            value=correct_classification,
            inference_id=response.inference_id,
        )

        await t0.feedback(
            metric_name="demonstration",
            value={
                "category": row.label,
            },
            inference_id=response.inference_id,
        )

    return correct_classification

In [10]:
semaphore = asyncio.Semaphore(CONCURRENCY)


async def process_document_with_semaphore(row):
    async with semaphore:
        return await process_document(row)

In [11]:
scores = await tqdm_asyncio.gather(
    *[process_document_with_semaphore(row) for _, row in train_df.iterrows()]
)


print(f"Train Set Accuracy: {np.mean(scores):.1%}")

100%|██████████| 200/200 [01:37<00:00,  2.06it/s]

Train Set Accuracy: 0.0%





In [12]:
scores = await tqdm_asyncio.gather(
    *[process_document_with_semaphore(row) for _, row in test_df.iterrows()]
)

print(f"Test Set Accuracy: {np.mean(scores):.1%}")

100%|██████████| 200/200 [01:46<00:00,  1.88it/s]

Test Set Accuracy: 0.0%



