In [1]:
import glob
from collections import defaultdict
import time
import json
from google.cloud import storage
from google.cloud import aiplatform_v1

In [2]:
# --- Vertex AI Vector Search deployment targeted by this notebook ---
# Host name handed to the match client as its API endpoint.
API_ENDPOINT = "1148065613.asia-southeast1-629242692180.vdb.vertexai.goog"
# Resource name of the index endpoint sent with every FindNeighbors request.
INDEX_ENDPOINT = "projects/629242692180/locations/asia-southeast1/indexEndpoints/43567048738996224"
# ID of the deployed index to search on that endpoint.
DEPLOYED_INDEX_ID = "equipment_profile_1766140195503"

# --- Evaluation data ---
# GCS prefix holding the pre-computed test embeddings.
TEST_DATA_PATH = "gs://axmt_equipment_profile/siglip_vectors_test"

# Client used for all similarity queries; it is pointed at API_ENDPOINT.
client_options = {"api_endpoint": API_ENDPOINT}
vector_search_client = aiplatform_v1.MatchServiceClient(client_options=client_options)

In [3]:
def average_precision(predicted_classes, true_class):
    """
    Average Precision (AP) of one ranked result list against a single label.

    Walks the ranked predictions (rank starts at 1); at every position whose
    class equals ``true_class`` the precision-at-that-rank is accumulated.
    The accumulated value is divided by the number of matching positions
    found in ``predicted_classes`` (not by the number of relevant items that
    exist in the index).

    Args:
        predicted_classes: Iterable of predicted class labels, best match first.
        true_class: The label considered relevant for this query.

    Returns:
        float: AP in [0, 1]; 0.0 when no prediction matches ``true_class``
        (including an empty prediction list).
    """
    hits = 0
    precision_total = 0.0

    for position, label in enumerate(predicted_classes, 1):
        if label != true_class:
            continue
        hits += 1
        # Precision at this rank = relevant items seen so far / items seen so far.
        precision_total += hits / position

    if hits == 0:
        return 0.0
    return precision_total / hits


In [16]:
def load_embeddings_from_gcs(gcs_path):
    """
    Read labelled embedding records from JSON-lines blobs under a GCS prefix.

    Only blobs whose name ends with ".json" are read; each non-blank line of
    such a blob is parsed as one JSON object. Records whose
    ``embedding_metadata.label_class`` is missing or empty are dropped.

    Args:
        gcs_path (str): Location such as "gs://bucket/some/prefix".

    Returns:
        list[dict]: One dict per kept record with keys "id", "embedding",
        "true_class" and "path".

    Raises:
        ValueError: If no record with a label_class was found.
    """
    client = storage.Client()

    # "gs://bucket/some/prefix" -> bucket name "bucket", object prefix "some/prefix"
    bucket_name, *prefix_segments = gcs_path.replace("gs://", "").split("/")
    object_prefix = "/".join(prefix_segments)

    records = []
    for blob in client.bucket(bucket_name).list_blobs(prefix=object_prefix):
        if not blob.name.endswith(".json"):
            continue

        non_blank_lines = (
            raw for raw in blob.download_as_text().splitlines() if raw.strip()
        )
        for raw_line in non_blank_lines:
            record = json.loads(raw_line)
            meta = record.get("embedding_metadata", {})
            label = meta.get("label_class")

            # A record without a label cannot be scored, so it is left out.
            if label:
                records.append({
                    "id": record.get("id"),
                    "embedding": record.get("embedding"),
                    "true_class": label,
                    "path": meta.get("original_path")
                })

    if len(records) == 0:
        raise ValueError("No valid embeddings with label_class found")

    return records

In [17]:
def query_vector_index(
    vector_search_client,
    index_endpoint,
    deployed_index_id,
    embedding,
    num_neighbors
):
    """
    Run one nearest-neighbour lookup against a deployed index.

    Args:
        vector_search_client: Client exposing ``find_neighbors(request)``.
        index_endpoint: Index endpoint value placed on the request.
        deployed_index_id: Deployed index ID placed on the request.
        embedding: Feature vector used as the query datapoint.
        num_neighbors: Number of neighbours requested for the query.

    Returns:
        The neighbours of the first (and only) query in the response, or an
        empty list when the response carries no nearest-neighbour results.
    """
    request = aiplatform_v1.FindNeighborsRequest(
        index_endpoint=index_endpoint,
        deployed_index_id=deployed_index_id,
        queries=[
            aiplatform_v1.FindNeighborsRequest.Query(
                datapoint=aiplatform_v1.IndexDatapoint(feature_vector=embedding),
                neighbor_count=num_neighbors
            )
        ],
        # Ask for the full datapoint so callers can read the neighbour's fields.
        return_full_datapoint=True,
    )

    per_query_results = vector_search_client.find_neighbors(request).nearest_neighbors

    # Exactly one query was sent, so only the first result entry is relevant.
    return per_query_results[0].neighbors if per_query_results else []


In [20]:
def evaluate_embeddings(
    test_data_path,
    vector_search_client,
    index_endpoint,
    deployed_index_id,
    num_neighbors=5,
    max_images_per_class=None
):
    """
    Evaluate vector retrieval using Top-1 Accuracy and mAP.

    Each test embedding loaded from ``test_data_path`` is sent as a query to
    the deployed index. The ``label_class`` found in every returned
    neighbour's metadata is compared with the query's own label to compute
    Top-1 correctness and Average Precision, overall and per class.

    Args:
        test_data_path: GCS path passed to ``load_embeddings_from_gcs``.
        vector_search_client: Client passed through to ``query_vector_index``.
        index_endpoint: Index endpoint passed through to ``query_vector_index``.
        deployed_index_id: Deployed index ID passed through to
            ``query_vector_index``.
        num_neighbors (int): Neighbours requested per query (default 5).
        max_images_per_class (int | None): When truthy, keep at most this many
            test records per true class (the first ones in load order).

    Returns:
        dict with keys:
            "summary": top1_accuracy, mAP, total_queries, num_classes
            "per_class": {class: {"top1_accuracy", "mAP"}}
            "details": one entry per scored query with its neighbours

    Notes:
        Queries for which the index returns no neighbours are skipped and are
        not counted in ``total_queries``.
        Exceptions raised by ``load_embeddings_from_gcs`` or
        ``query_vector_index`` propagate unchanged.
    """

    # ---------- Load data ----------
    embeddings = load_embeddings_from_gcs(test_data_path)

    if max_images_per_class:
        by_class = defaultdict(list)
        for e in embeddings:
            # The loader stores the label under "true_class"; the previous
            # "class" key does not exist and raised KeyError here.
            by_class[e["true_class"]].append(e)

        embeddings = [
            emb
            for class_embs in by_class.values()
            for emb in class_embs[:max_images_per_class]
        ]

    # ---------- Metrics ----------
    total_queries = 0
    top1_correct = 0
    average_precisions = []

    per_class = defaultdict(lambda: {
        "total": 0,
        "top1_correct": 0,
        "aps": []
    })

    detailed_results = []

    # ---------- Evaluation loop ----------
    # NOTE(review): if the test vectors are also stored in the deployed index,
    # every query would retrieve itself first and inflate the scores - confirm
    # that the index excludes the test set.
    for emb in embeddings:
        true_class = emb["true_class"]

        neighbors = query_vector_index(
            vector_search_client,
            index_endpoint,
            deployed_index_id,
            emb["embedding"],
            num_neighbors
        )

        if not neighbors:
            # Nothing to score for this query; it is excluded from all totals.
            continue

        predicted_classes = []
        neighbor_details = []

        for n in neighbors:
            meta = n.datapoint.embedding_metadata or {}
            # pred_class is None when the neighbour carries no label_class.
            pred_class = meta.get("label_class")

            predicted_classes.append(pred_class)
            neighbor_details.append({
                "id": n.datapoint.datapoint_id,
                "predicted_class": pred_class,
                "distance": n.distance,
                "path": meta.get("original_path")
            })

        # ---------- Metrics ----------
        is_top1 = predicted_classes[0] == true_class
        ap = average_precision(predicted_classes, true_class)

        total_queries += 1
        top1_correct += int(is_top1)
        average_precisions.append(ap)

        per_class[true_class]["total"] += 1
        per_class[true_class]["top1_correct"] += int(is_top1)
        per_class[true_class]["aps"].append(ap)

        detailed_results.append({
            "query_path": emb["path"],
            "true_class": true_class,
            "predicted_classes": predicted_classes,
            "top1_correct": is_top1,
            "average_precision": ap,
            "neighbors": neighbor_details
        })

    # ---------- Final metrics ----------
    top1_accuracy = top1_correct / total_queries if total_queries else 0.0
    mean_ap = sum(average_precisions) / len(average_precisions) if average_precisions else 0.0

    class_metrics = {
        cls: {
            "top1_accuracy": v["top1_correct"] / v["total"] if v["total"] else 0.0,
            "mAP": sum(v["aps"]) / len(v["aps"]) if v["aps"] else 0.0
        }
        for cls, v in per_class.items()
    }

    return {
        "summary": {
            "top1_accuracy": top1_accuracy,
            "mAP": mean_ap,
            "total_queries": total_queries,
            "num_classes": len(class_metrics)
        },
        "per_class": class_metrics,
        "details": detailed_results
    }


In [21]:
# Score every test embedding against the deployed index (no per-class cap).
print("Running evaluation using pre-computed embeddings from GCS ...")

results = evaluate_embeddings(
    TEST_DATA_PATH,
    vector_search_client,
    INDEX_ENDPOINT,
    DEPLOYED_INDEX_ID,
    num_neighbors=5,
    max_images_per_class=None,
)

print("Finished evaluation.")


Running evaluation using pre-computed embeddings from GCS ...
Finished evaluation.


In [33]:
def display_evaluation_results(
    results,
    inspection: bool = False,
    topk: int = 5,
    max_rows: int = 20,
):
    """
    Print evaluation results as plain-text tables.

    Args:
        results (dict): Evaluation output with optional keys "summary",
            "per_class" and "details"; missing keys are treated as empty.
        inspection (bool): Also print a per-query label -> prediction table.
        topk (int): Number of leading predictions shown per inspected query;
            also used in that column's header.
        max_rows (int): Maximum number of inspection rows to print.

    Returns:
        None. Everything is written to stdout.
    """
    summary = results.get("summary", {})
    per_class = results.get("per_class", {})
    details = results.get("details", [])

    def safe(v, default=0.0):
        # Non-numeric values (e.g. None) cannot take a numeric format spec.
        return v if isinstance(v, (int, float)) else default

    print("=" * 100)
    print("EQUIPMENT PROFILE VECTOR SEARCH - EVALUATION RESULTS")
    print("=" * 100)

    # ---------- Overall ----------
    print("\nOVERALL PERFORMANCE")
    print("-" * 100)
    print(f"Total Queries     : {summary.get('total_queries', 0)}")
    print(f"Total Classes     : {summary.get('num_classes', 0)}")
    print(f"Top-1 Accuracy    : {safe(summary.get('top1_accuracy')) * 100:.2f}%")
    print(f"Mean AP (mAP)     : {safe(summary.get('mAP')):.4f}")

    # ---------- Per-class ----------
    print("\nPER-CLASS METRICS")
    print("-" * 100)
    print(f"{'Class':<15} {'Top-1 (%)':>12} {'mAP':>10}")
    print("-" * 100)

    # Best-performing classes (by mAP) first.
    for cls, metrics in sorted(
        per_class.items(),
        key=lambda x: safe(x[1].get("mAP")),
        reverse=True
    ):
        print(
            f"{str(cls):<15} "
            f"{safe(metrics.get('top1_accuracy')) * 100:>12.2f} "
            f"{safe(metrics.get('mAP')):>10.4f}"
        )

    # ---------- Optional inspection ----------
    if inspection:
        # Header follows the requested topk instead of a fixed "Top-5".
        topk_header = f"Top-{topk} Predictions"

        print("\nLABEL → PREDICTION INSPECTION")
        print("-" * 100)
        print(
            f"{'True Label':<15} "
            f"{'Top-1 Pred':<15} "
            f"{topk_header:<40} "
            f"{'AP':>6}"
        )
        print("-" * 100)

        for row in details[:max_rows]:
            true_label = row.get("true_class", "N/A")
            preds = row.get("predicted_classes", [])
            ap = safe(row.get("average_precision"))

            top1 = preds[0] if preds else "N/A"
            top_preds = preds[:topk]

            # Labels may be None (a neighbour without label_class); None does
            # not accept an alignment format spec, so convert to str first.
            print(
                f"{str(true_label):<15} "
                f"{str(top1):<15} "
                f"{str(top_preds):<40} "
                f"{ap:>6.3f}"
            )

        if len(details) > max_rows:
            print(f"\n... showing first {max_rows} of {len(details)} queries")

    print("\n" + "=" * 100)

In [34]:
# Print the summary and per-class tables plus the per-query inspection table
# (default topk=5, first 20 queries) for the `results` computed above.
display_evaluation_results(results, inspection=True)

EQUIPMENT PROFILE VECTOR SEARCH - EVALUATION RESULTS

OVERALL PERFORMANCE
----------------------------------------------------------------------------------------------------
Total Queries     : 537
Total Classes     : 21
Top-1 Accuracy    : 100.00%
Mean AP (mAP)     : 1.0000

PER-CLASS METRICS
----------------------------------------------------------------------------------------------------
Class              Top-1 (%)        mAP
----------------------------------------------------------------------------------------------------
AI1                   100.00     1.0000
AI10                  100.00     1.0000
AI11                  100.00     1.0000
AI12                  100.00     1.0000
AI13                  100.00     1.0000
AI14                  100.00     1.0000
AI15                  100.00     1.0000
AI16                  100.00     1.0000
AI17                  100.00     1.0000
AI18                  100.00     1.0000
AI19                  100.00     1.0000
AI2                   