# Example: Multimodal (Vision) Finetuning


In [None]:
import asyncio
import base64
import os
from os import PathLike
from pathlib import Path

import numpy as np
import pandas as pd
from tensorzero import AsyncTensorZeroGateway
from tqdm.asyncio import tqdm_asyncio

In [None]:
# Root directory of the dataset: holds labels.csv plus the document images.
DATA_PATH = Path("data")

# CSV with one row per document (columns used below: document, label, is_train).
LABELS_PATH = DATA_PATH.joinpath("labels.csv")

# Upper bound on documents processed concurrently.
CONCURRENCY = 10

# TensorZero variant requested for every inference call.
VARIANT_NAME = "baseline"

In [None]:
def load_data(path: PathLike):
    """Load the labels CSV under *path* and split it into train and test sets.

    The CSV is read from ``<path>/labels.csv``, and every row's ``document``
    entry is resolved relative to *path*; all referenced images must exist.

    Args:
        path: Dataset root directory containing ``labels.csv`` and the images.

    Returns:
        A ``(train_df, test_df)`` tuple holding the rows with ``is_train == 1``
        and ``is_train == 0`` respectively, each re-indexed from 0.

    Raises:
        FileNotFoundError: If the labels file or any referenced image is missing.
    """
    root = Path(path)
    # Derive the labels file from the argument (not a global) so the CSV and
    # the images it references always come from the same dataset root.
    labels_path = root / "labels.csv"
    if not labels_path.exists():
        raise FileNotFoundError(
            f"Labels file {labels_path} does not exist. See the README.md and ensure you've downloaded the dataset correctly."
        )

    df = pd.read_csv(labels_path)

    # Sanity check: every referenced image must exist. Collect all missing
    # files so an incomplete download is reported in a single error.
    image_paths = [root / document for document in df["document"]]
    missing = [image_path for image_path in image_paths if not image_path.exists()]
    if missing:
        raise FileNotFoundError(
            f"Image {missing[0]} does not exist ({len(missing)} missing in total). See the README.md and ensure you've downloaded the dataset correctly."
        )

    train_df = df[df["is_train"] == 1].reset_index(drop=True)
    test_df = df[df["is_train"] == 0].reset_index(drop=True)

    return train_df, test_df


train_df, test_df = load_data(DATA_PATH)

n_train, n_test = len(train_df), len(test_df)
print(f"Found {n_train} train documents and {n_test} test documents")

In [None]:
# Preview five randomly chosen training rows (unseeded, so they vary per run).
train_df.sample(n=5)

In [None]:
# Preview five randomly chosen test rows (unseeded, so they vary per run).
test_df.sample(n=5)

In [None]:
# Ensure the local object-storage directory exists (presumably referenced by the
# gateway's object-storage configuration — confirm against the TensorZero config).
Path("tensorzero/object_storage").mkdir(parents=True, exist_ok=True)

In [None]:
t0 = await AsyncTensorZeroGateway.build_http(
    gateway_url="http://localhost:3000",
)

In [None]:
def load_document(path: PathLike):
    """Load a PNG document image and return its contents base64-encoded.

    Args:
        path: Image path, resolved relative to ``DATA_PATH``.

    Returns:
        The file's bytes encoded as a base64 ``str``.

    Raises:
        ValueError: If the file does not have a ``.png`` suffix (case-insensitive).
        FileNotFoundError: If the resolved file does not exist.
    """
    full_path = DATA_PATH / path
    # Callers send this data with mime_type "image/png", so reject anything else
    # explicitly (an `assert` would be stripped under `python -O`).
    if full_path.suffix.lower() != ".png":
        raise ValueError(f"Expected a .png document, got {full_path}")

    # read_bytes raises FileNotFoundError (with the path) if the file is missing.
    return base64.b64encode(full_path.read_bytes()).decode("utf-8")

In [None]:
async def process_document(row):
    """Classify one document via TensorZero and, for train rows, send feedback.

    Test rows (``is_train`` falsy) are sent with ``dryrun=True`` and receive no
    feedback; train rows get both a ``correct_classification`` metric and a
    ``demonstration`` holding the ground-truth category.

    Args:
        row: A labels row exposing ``document``, ``label`` and ``is_train``
            attributes (e.g. a pandas ``Series`` from ``DataFrame.iterrows``).

    Returns:
        ``True`` if the predicted category equals ``row.label``, else ``False``.
    """
    response = await t0.inference(
        function_name="classify_document",
        input={
            "system": "Categorize this document.",
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "image",
                            "mime_type": "image/png",
                            "data": load_document(row.document),
                        },
                    ],
                }
            ],
        },
        dryrun=not row.is_train,
        cache_options={
            "enabled": "on",
        },
        variant_name=VARIANT_NAME,
    )

    # `parsed` is None when the model's JSON output fails to parse/validate.
    # Count that as a misclassification instead of raising a TypeError that
    # would abort the whole concurrent evaluation run.
    parsed = response.output.parsed or {}
    predicted_category = parsed.get("category")
    # Coerce to a plain bool so the feedback value stays JSON-serializable even
    # if the comparison yields a NumPy scalar.
    correct_classification = bool(predicted_category == row.label)

    if row.is_train:
        await t0.feedback(
            metric_name="correct_classification",
            value=correct_classification,
            inference_id=response.inference_id,
        )

        await t0.feedback(
            metric_name="demonstration",
            value={
                "category": row.label,
            },
            inference_id=response.inference_id,
        )

    return correct_classification

In [None]:
# Shared limiter: at most CONCURRENCY documents are processed at once.
semaphore = asyncio.Semaphore(CONCURRENCY)


async def process_document_with_semaphore(row):
    """Run ``process_document(row)`` while holding one slot of ``semaphore``."""
    await semaphore.acquire()
    try:
        return await process_document(row)
    finally:
        semaphore.release()

In [None]:
scores = await tqdm_asyncio.gather(
    *[process_document_with_semaphore(row) for _, row in train_df.iterrows()]
)

print(f"Train Set Accuracy: {np.mean(scores):.1%}")

In [None]:
scores = await tqdm_asyncio.gather(
    *[process_document_with_semaphore(row) for _, row in test_df.iterrows()]
)

print(f"Test Set Accuracy: {np.mean(scores):.1%}")