Copy the long code block below to test any pipeline that takes the usual four inputs (image, dining_hall_id, meal_time, date) and returns a list of outputs (one item ID plus num_servings per item).

The pipeline **must** define a `predict()` function with the following definition:

```python3

@dataclass
class PredictedItem:
    id: str
    num_servings: float

def predict(image, dining_hall_id, meal_time, date) -> [PredictedItem]:
    return [
      PredictedItem(id="12345", num_servings=1.0),
      PredictedItem(id="67890", num_servings=0.5),
    ]
```

In [None]:
# ---- CONFIG ----
from getpass import getpass

# Base URL of the MenuMatch API; the dataset and presign endpoints are built from it.
API_BASE_URL = "https://3vw53n9900.execute-api.us-east-1.amazonaws.com/dev"
# Prompted for at run time (input is not echoed) and sent as the X-Api-Key header.
API_TOKEN = getpass("Enter MenuMatch API token: ")

# Base URL of the HuskyEats API used for per-item nutrition lookups.
HUSKYEATS_BASE_URL = "https://husky-eats.onrender.com/api"

# ---- IMPORTS ----
import math
from dataclasses import dataclass
from io import BytesIO
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple

import requests
from PIL import Image


# ---- LOW-LEVEL API HELPERS ----

def _auth_headers() -> Dict[str, str]:
    """Build the HTTP headers that authenticate a MenuMatch API request."""
    headers: Dict[str, str] = dict()
    # The token is read at call time, so re-running the config cell takes effect.
    headers["X-Api-Key"] = API_TOKEN
    return headers


def fetch_dataset_metadata() -> List[Dict[str, Any]]:
    """Fetch the metadata records for the dataset.

    Issues ``GET {API_BASE_URL}/dataset`` with the API-key header.

    Returns:
        The value stored under ``"items"`` when the response is a JSON object
        that has that key; otherwise the decoded payload itself (for example
        a bare JSON array of records).

    Raises:
        requests.HTTPError: If the API answers with a non-success status.
        requests.Timeout: If the server does not respond within 30 seconds.
    """
    url = f"{API_BASE_URL}/dataset"
    # Without an explicit timeout, requests waits forever on a stalled server.
    resp = requests.get(url, headers=_auth_headers(), timeout=30)
    resp.raise_for_status()
    data = resp.json()
    # Only a JSON object has .get(); a bare array is already the record list.
    if isinstance(data, dict):
        return data.get("items", data)
    return data


def get_download_url(object_key: str, bucket: Optional[str] = None) -> str:
    """Request a presigned download URL for a stored object.

    Issues ``POST {API_BASE_URL}/downloads/presign`` with the API-key header.

    Args:
        object_key: Key of the object to download; sent as ``objectKey``.
        bucket: Optional bucket name; only sent when it is non-empty.

    Returns:
        The ``downloadUrl`` field of the JSON response.

    Raises:
        requests.HTTPError: If the API answers with a non-success status.
        requests.Timeout: If the server does not respond within 30 seconds.
        KeyError: If the response has no ``downloadUrl`` field.
    """
    url = f"{API_BASE_URL}/downloads/presign"
    payload: Dict[str, Any] = {"objectKey": object_key}
    if bucket:
        payload["bucket"] = bucket

    # Without an explicit timeout, requests waits forever on a stalled server.
    resp = requests.post(url, headers=_auth_headers(), json=payload, timeout=30)
    resp.raise_for_status()
    data = resp.json()
    return data["downloadUrl"]


def load_image(object_key: str, bucket: Optional[str] = None) -> Image.Image:
    """Download a stored image and decode it as an RGB PIL image.

    Args:
        object_key: Key of the image object, passed to ``get_download_url``.
        bucket: Optional bucket name, passed to ``get_download_url``.

    Returns:
        The downloaded image converted to mode ``"RGB"``.

    Raises:
        requests.HTTPError: If presigning or the download returns a
            non-success status.
        requests.Timeout: If the download server stops responding for more
            than 30 seconds.
    """
    download_url = get_download_url(object_key, bucket=bucket)
    # The timeout bounds each wait on the socket, not the total transfer time,
    # so large images still download while a dead connection fails fast.
    resp = requests.get(download_url, timeout=30)
    resp.raise_for_status()
    img = Image.open(BytesIO(resp.content)).convert("RGB")
    return img


def get_nutrition_for_id(menu_item_id: str) -> Dict[str, float]:
    """Look up the nutrition facts of one menu item from HuskyEats.

    Args:
        menu_item_id: Identifier interpolated into the ``/menuitem/{id}`` path.

    Returns:
        A dict with the keys ``kcal``, ``protein_g``, ``carb_g`` and ``fat_g``,
        each converted to ``float`` from the corresponding response field.

    Raises:
        requests.HTTPError: If the API answers with a non-success status
            (for example an unknown item ID).
        requests.Timeout: If the server does not respond within 120 seconds.
        KeyError: If an expected field is missing from the response.
    """
    # HuskyEats: GET /menuitem/{id}
    url = f"{HUSKYEATS_BASE_URL}/menuitem/{menu_item_id}"
    # Bound the wait so an unresponsive server cannot hang an evaluation run;
    # the limit is deliberately generous so a slow response is not cut off.
    resp = requests.get(url, timeout=120)
    resp.raise_for_status()
    data = resp.json()

    return {
        "kcal": float(data["calories"]),
        "protein_g": float(data["protein_g"]),
        "carb_g": float(data["totalcarbohydrate_g"]),
        "fat_g": float(data["totalfat_g"]),
    }


# ---- DATA STRUCTURES ----

@dataclass
class GroundTruthItem:
    """One ground-truth menu item attached to a dataset sample."""

    # Menu item identifier, stored as a string.
    id: str
    # Number of servings of that item.
    num_servings: float


@dataclass
class Sample:
    """One dataset sample: the image, its context, and its ground truth."""

    # Storage key the image was downloaded from; identifies the sample.
    object_key: str
    # Decoded image (iter_samples loads it in RGB mode).
    image: Image.Image
    # The three context inputs handed to the predict function with the image.
    dining_hall_id: str
    meal_time: str
    date: str
    # Optional difficulty label copied from the sample's metadata.
    difficulty: Optional[str]
    # Labelled items and servings for this sample.
    ground_truth: List[GroundTruthItem]


def iter_samples(limit: Optional[int] = None) -> Iterable[Sample]:
    """Lazily yield dataset samples, downloading each image on demand.

    Args:
        limit: When not ``None``, only the first ``limit`` metadata records
            are used.

    Yields:
        One ``Sample`` per metadata record, in the order the API returned them.
    """
    records = fetch_dataset_metadata()
    if limit is not None:
        records = records[:limit]

    for record in records:
        key = record["objectKey"]
        picture = load_image(key, bucket=record.get("bucket"))

        # Normalise the labelled items: IDs become strings, servings floats.
        truth: List[GroundTruthItem] = []
        for entry in record.get("items", []):
            truth.append(
                GroundTruthItem(
                    id=str(entry["menuItemId"]),
                    num_servings=float(entry["servings"]),
                )
            )

        hall = str(record.get("diningHallId"))
        meal = str(record.get("mealtime"))
        day = str(record.get("mealDate"))

        yield Sample(
            object_key=key,
            image=picture,
            dining_hall_id=hall,
            meal_time=meal,
            date=day,
            difficulty=record.get("difficulty"),
            ground_truth=truth,
        )


# ---- RUNNER ----

def run_model_on_dataset(
    predict_fn=None,
    limit: Optional[int] = None,
):
    """Run a prediction function over the dataset and collect the outcomes.

    Args:
        predict_fn: Callable invoked as
            ``predict_fn(image, dining_hall_id, meal_time, date)``. When
            ``None``, the module-level ``predict`` function is used.
        limit: Forwarded to ``iter_samples`` to cap the number of samples.

    Returns:
        A list with one dict per sample holding the sample's identifying
        fields, its ground truth and the predictions returned for it.

    Raises:
        ValueError: If ``predict_fn`` is ``None`` and no global ``predict``
            exists.
    """
    if predict_fn is None:
        module_scope = globals()
        try:
            predict_fn = module_scope["predict"]
        except KeyError:
            raise ValueError("No predict_fn provided and no global `predict` defined.")

    def evaluate(sample):
        # Predict first, then bundle the outcome with the sample's context.
        predicted = predict_fn(
            sample.image,
            sample.dining_hall_id,
            sample.meal_time,
            sample.date,
        )
        return {
            "object_key": sample.object_key,
            "dining_hall_id": sample.dining_hall_id,
            "meal_time": sample.meal_time,
            "date": sample.date,
            "ground_truth": sample.ground_truth,
            "predictions": predicted,
        }

    return [evaluate(sample) for sample in iter_samples(limit=limit)]


# ---- METRICS ----

def _items_to_dict(items: Iterable[Any]) -> Dict[str, float]:
    out: Dict[str, float] = {}
    for it in items:
        if hasattr(it, "id"):
            _id = str(it.id)
            servings = float(it.num_servings)
        else:
            _id = str(it["id"])
            servings = float(it.get("num_servings", 0.0))
        out[_id] = servings
    return out


def _mean_or_zero(values: List[float]) -> float:
    """Return the arithmetic mean of ``values``, or 0.0 when the list is empty."""
    return sum(values) / len(values) if values else 0.0


def _macro_nutrient_metrics(
    servings_pairs: List[Tuple[Dict[str, float], Dict[str, float]]],
    get_nutrition_for_id: Callable[[str], Dict[str, float]],
    macro_nutrients: Tuple[str, ...],
) -> Dict[str, float]:
    """Score per-sample nutrient totals of predictions against ground truth."""
    nutr_cache: Dict[str, Dict[str, float]] = {}

    def totals(servings_by_id: Dict[str, float]) -> Dict[str, float]:
        acc = {n: 0.0 for n in macro_nutrients}
        for item_id, servings in servings_by_id.items():
            # Look every item up at most once; a lookup can be expensive
            # (the HuskyEats helper in this file does one HTTP request each).
            if item_id not in nutr_cache:
                nutr_cache[item_id] = get_nutrition_for_id(item_id)
            info = nutr_cache[item_id]
            for n in macro_nutrients:
                acc[n] += servings * float(info[n])
        return acc

    abs_errors: Dict[str, List[float]] = {n: [] for n in macro_nutrients}
    perc_errors: Dict[str, List[float]] = {n: [] for n in macro_nutrients}

    for gt, pr in servings_pairs:
        gt_tot = totals(gt)
        pr_tot = totals(pr)
        for n in macro_nutrients:
            ae = abs(pr_tot[n] - gt_tot[n])
            abs_errors[n].append(ae)
            # A relative error is undefined when the true total is zero.
            if gt_tot[n] > 0:
                perc_errors[n].append(ae / gt_tot[n])

    out: Dict[str, float] = {}
    for n in macro_nutrients:
        out[f"macro_mae_{n}"] = _mean_or_zero(abs_errors[n])
        out[f"macro_pmae_{n}"] = _mean_or_zero(perc_errors[n])
    return out


def compute_all_metrics(
    results: List[Dict[str, Any]],
    get_nutrition_for_id: Optional[Callable[[str], Dict[str, float]]] = None,
    macro_nutrients: Tuple[str, ...] = ("kcal", "protein_g", "carb_g", "fat_g"),
) -> Dict[str, float]:
    """Compute classification, portion and (optionally) nutrient metrics.

    Args:
        results: One dict per sample with ``"ground_truth"`` and
            ``"predictions"`` entries, each an iterable of items accepted by
            ``_items_to_dict``.
        get_nutrition_for_id: Optional callable mapping an item ID to a dict
            of nutrient values. When given, ``macro_*`` metrics are added.
        macro_nutrients: Names of the nutrient fields to score.

    Returns:
        A dict of metric name to value:

        * ``cls_precision`` / ``cls_recall`` / ``cls_f1``: computed from true
          positives, false positives and false negatives pooled over all
          samples; only items with more than zero servings count as present.
        * ``cls_avg_jaccard``: mean per-sample Jaccard index of the present
          item sets (1.0 when both sets are empty).
        * ``cls_exact_match``: fraction of samples whose item sets are equal.
        * ``portion_mae_servings`` / ``portion_rmse_servings``: servings error
          over every item in either side, a missing item counting as 0.
        * ``portion_pmae_servings``: mean relative servings error over items
          whose true servings are positive.
        * ``macro_mae_<n>`` / ``macro_pmae_<n>``: absolute and relative error
          of the per-sample total of nutrient ``<n>``.

        Every average falls back to 0.0 when it has no data points.

    Raises:
        Exception: Whatever ``get_nutrition_for_id`` raises is propagated.
    """
    # Convert each sample exactly once: the inputs may be one-shot iterables,
    # which a second pass (for the nutrient metrics) would find exhausted.
    servings_pairs = [
        (_items_to_dict(r["ground_truth"]), _items_to_dict(r["predictions"]))
        for r in results
    ]

    tp = fp = fn = 0
    jaccards: List[float] = []
    exact_match_count = 0

    abs_errors: List[float] = []
    sq_errors: List[float] = []
    perc_errors: List[float] = []

    for gt, pr in servings_pairs:
        # Zero-serving entries do not count as the item being present.
        gt_ids = {k for k, v in gt.items() if v > 0}
        pr_ids = {k for k, v in pr.items() if v > 0}

        inter = gt_ids & pr_ids
        tp += len(inter)
        fp += len(pr_ids - gt_ids)
        fn += len(gt_ids - pr_ids)

        union = gt_ids | pr_ids
        jaccards.append(len(inter) / len(union) if union else 1.0)

        if gt_ids == pr_ids:
            exact_match_count += 1

        for item_id in set(gt) | set(pr):
            g = gt.get(item_id, 0.0)
            p = pr.get(item_id, 0.0)
            err = p - g
            ae = abs(err)
            abs_errors.append(ae)
            sq_errors.append(err * err)
            if g > 0:
                perc_errors.append(ae / g)

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    metrics: Dict[str, float] = {
        "cls_precision": precision,
        "cls_recall": recall,
        "cls_f1": f1,
        "cls_avg_jaccard": _mean_or_zero(jaccards),
        "cls_exact_match": exact_match_count / len(servings_pairs) if servings_pairs else 0.0,
        "portion_mae_servings": _mean_or_zero(abs_errors),
        "portion_rmse_servings": math.sqrt(_mean_or_zero(sq_errors)),
        "portion_pmae_servings": _mean_or_zero(perc_errors),
    }

    if get_nutrition_for_id is not None:
        metrics.update(
            _macro_nutrient_metrics(servings_pairs, get_nutrition_for_id, macro_nutrients)
        )

    return metrics