In [22]:
import glob
from collections import defaultdict
import time
import json
from google.cloud import storage
from google.cloud import aiplatform_v1

In [27]:
API_ENDPOINT="1148065613.asia-southeast1-629242692180.vdb.vertexai.goog"
INDEX_ENDPOINT="projects/629242692180/locations/asia-southeast1/indexEndpoints/43567048738996224"
DEPLOYED_INDEX_ID="equipment_profile_1766140195503"
TEST_DATA_PATH="gs://axmt_equipment_profile/siglip_vectors_test"

client_options = {
  "api_endpoint": API_ENDPOINT
}
vector_search_client = aiplatform_v1.MatchServiceClient(
  client_options=client_options,
)

In [50]:
def evaluate_equipment_test_dataset_from_embeddings(test_data_path=TEST_DATA_PATH,
                                     index_endpoint=INDEX_ENDPOINT, 
                                     deployed_index_id=DEPLOYED_INDEX_ID, 
                                     num_neighbors=5,
                                     max_images_per_class=None):
    # Initialize results tracking
    results = {
        'total_queries': 0,
        'correct_predictions': 0,
        'class_results': defaultdict(lambda: {'total': 0, 'correct': 0, 'predictions': []}),
        'detailed_results': []
    }
    
    # Load test embeddings from GCS
    test_embeddings = []
    
    print(f"Loading test embeddings from GCS path: {test_data_path}")
    
    # Parse GCS path
    gcs_path = test_data_path.replace('gs://', '')
    bucket_name = gcs_path.split('/')[0]
    blob_prefix = '/'.join(gcs_path.split('/')[1:]) if len(gcs_path.split('/')) > 1 else ''
    
    # Initialize GCS client
    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)
    
    # List all JSON files in the GCS path
    blobs = bucket.list_blobs(prefix=blob_prefix)
    
    for blob in blobs:
        if blob.name.endswith('.json'):
            print(f"Processing GCS file: {blob.name}")
            
            try:
                # Download and parse JSON content
                content = blob.download_as_text()
                
                for line in content.splitlines():
                    if line.strip():
                        try:
                            embedding_data = json.loads(line)
                            meta_data = embedding_data.get('embedding_metadata', '')
                            
                            # Extract relevant information
                            original_path = meta_data.get('original_path', '')
                            label_class_val = meta_data.get('label_class', '')
                            embedding_vector = embedding_data.get('embedding', [])
                            data_id = embedding_data.get('id', '')
                            
                            # Filter for test images (assuming they contain 'test' in path)
                            test_embeddings.append({
                                    'path': original_path,
                                    'class': label_class_val,
                                    'id': data_id,
                                    'embedding': embedding_vector
                                })    
                                
                        except json.JSONDecodeError as e:
                            print(f"Error parsing JSON line: {e}")
                            continue
                            
            except Exception as e:
                print(f"Error processing GCS blob {blob.name}: {e}")
                continue
                
    if not test_embeddings:
        raise ValueError("No test embeddings found. Please check your GCS path contains test data.")
    
    print(f"Found {len(test_embeddings)} test embeddings from GCS")
    
    # Group by class for max_images_per_class filtering
    embeddings_by_class = defaultdict(list)
    for emb_data in test_embeddings:
        embeddings_by_class[emb_data['class']].append(emb_data)
    
    # Apply max_images_per_class limit if specified
    if max_images_per_class:
        filtered_embeddings = []
        for class_name, class_embeddings in embeddings_by_class.items():
            filtered_embeddings.extend(class_embeddings[:max_images_per_class])
        test_embeddings = filtered_embeddings
        
    # Get unique classes
    equipment_classes = list(set([emb['class'] for emb in test_embeddings]))
    
    print(f"Found {len(equipment_classes)} equipment classes: {sorted(equipment_classes)}")
    print(f"Total test embeddings: {len(test_embeddings)}")
    print(f"Starting evaluation with {num_neighbors} neighbors per query...\n")
    
    for emb_data in test_embeddings:
        original_path = emb_data['path']
        true_class = emb_data['class']
        query_embedding = emb_data['embedding']
        
        try:
            # Create vector search request using pre-computed embedding
            datapoint = aiplatform_v1.IndexDatapoint(
                feature_vector=query_embedding
            )

            query = aiplatform_v1.FindNeighborsRequest.Query(
                datapoint=datapoint,
                neighbor_count=num_neighbors
            )

            request = aiplatform_v1.FindNeighborsRequest(
                index_endpoint=index_endpoint,
                deployed_index_id=deployed_index_id,
                queries=[query],
                return_full_datapoint=True,
            )

            # Execute the request
            response = vector_search_client.find_neighbors(request)
            
            if response and response.nearest_neighbors:
                # Get the top prediction
                neighbors = response.nearest_neighbors[0].neighbors
                
                # Extract predicted classes from neighbors
                predicted_classes = []
                neighbor_details = []
                
                for neighbor in neighbors:
                    neighbor_id = neighbor.datapoint.datapoint_id
                    distance = neighbor.distance
                    metadata = neighbor.datapoint.embedding_metadata
                    if metadata:
                        label = neighbor.datapoint.embedding_metadata.get('label_class')
                        original_path = neighbor.datapoint.embedding_metadata.get('original_path')
                    
                    # Get the predicted class from the neighbor ID
                    pred_class = label
                    predicted_classes.append(pred_class)
                    neighbor_details.append({
                        'id': neighbor_id,
                        'predicted_class': pred_class,
                        'distance': distance,
                        'original_path': original_path
                    })    

                # Determine if prediction is correct (top-1 and top-k accuracy)
                top1_correct = predicted_classes[0] == true_class if predicted_classes else False
                topk_correct = true_class in predicted_classes[:num_neighbors] if predicted_classes else False

                # Update results
                results['total_queries'] += 1
                results['class_results'][true_class]['total'] += 1

                if top1_correct:
                    results['correct_predictions'] += 1
                    results['class_results'][true_class]['correct'] += 1
                    
                # Store detailed results
                result_detail = {
                    'query_image': original_path,
                    'true_class': true_class,
                    'predicted_classes': predicted_classes,
                    'top1_correct': top1_correct,
                    'topk_correct': topk_correct,
                    'neighbors': neighbor_details
                }
                results['detailed_results'].append(result_detail)
                results['class_results'][true_class]['predictions'].append(result_detail)
                
                # Print progress
                if results['total_queries'] % 100 == 0:
                    current_accuracy = (results['correct_predictions'] / results['total_queries']) * 100
                    print(f"Processed {results['total_queries']} embeddings, Current accuracy: {current_accuracy:.2f}%")
            
            else:
                print(f"No response for embedding: {original_path}")
                
            # Small delay to avoid overwhelming the API
            time.sleep(0.1)
            
        except Exception as e:
            print(f"Error processing embedding {original_path}: {e}")
            continue
            
    # Calculate final metrics
    overall_accuracy = (results['correct_predictions'] / results['total_queries']) * 100 if results['total_queries'] > 0 else 0
    
    # Calculate per-class accuracy
    class_accuracies = {}
    for class_name, class_data in results['class_results'].items():
        if class_data['total'] > 0:
            class_accuracies[class_name] = (class_data['correct'] / class_data['total']) * 100
        else:
            class_accuracies[class_name] = 0
    
    # Add summary to results
    results['summary'] = {
        'overall_accuracy': overall_accuracy,
        'class_accuracies': class_accuracies,
        'total_classes': len(equipment_classes),
        'avg_class_accuracy': sum(class_accuracies.values()) / len(class_accuracies) if class_accuracies else 0
    }
            
    return results
    

In [51]:
print("Running evaluation using pre-computed embeddings from GCS ...")
results = evaluate_equipment_test_dataset_from_embeddings(
    test_data_path=TEST_DATA_PATH,  # GCS path with embedding JSON files
    index_endpoint=INDEX_ENDPOINT,
    deployed_index_id=DEPLOYED_INDEX_ID,
    num_neighbors=5,
    max_images_per_class=None
)
print("Finish!!")


Running evaluation using pre-computed embeddings from GCS ...
Loading test embeddings from GCS path: gs://axmt_equipment_profile/siglip_vectors_test
Processing GCS file: siglip_vectors_test/local_image_vectors_test.json
Found 537 test embeddings from GCS
Found 21 equipment classes: ['AI1', 'AI10', 'AI11', 'AI12', 'AI13', 'AI14', 'AI15', 'AI16', 'AI17', 'AI18', 'AI19', 'AI2', 'AI21', 'AI22', 'AI3', 'AI4', 'AI5', 'AI6', 'AI7', 'AI8', 'AI9']
Total test embeddings: 537
Starting evaluation with 5 neighbors per query...

Processed 100 embeddings, Current accuracy: 100.00%
Processed 200 embeddings, Current accuracy: 100.00%
Processed 300 embeddings, Current accuracy: 100.00%
Processed 400 embeddings, Current accuracy: 100.00%
Processed 500 embeddings, Current accuracy: 100.00%
Finish!!


In [54]:
def display_evaluation_results(results):
    """
    Display comprehensive evaluation results in a formatted way.
    
    Args:
        results (dict): Evaluation results containing accuracy metrics and detailed results
    """
    print("=" * 80)
    print("EQUIPMENT PROFILE VECTOR SEARCH - EVALUATION RESULTS")
    print("=" * 80)
    
    # Overall metrics
    summary = results.get('summary', {})
    print(f"\nðŸ“Š OVERALL PERFORMANCE:")
    print(f"   Total Queries: {results.get('total_queries', 0)}")
    print(f"   Correct Predictions: {results.get('correct_predictions', 0)}")
    print(f"   Overall Accuracy: {summary.get('overall_accuracy', 0):.2f}%")
    print(f"   Average Class Accuracy: {summary.get('avg_class_accuracy', 0):.2f}%")
    print(f"   Total Classes: {summary.get('total_classes', 0)}")
    
    # Per-class accuracy
    print(f"\nðŸŽ¯ PER-CLASS ACCURACY:")
    class_accuracies = summary.get('class_accuracies', {})
    class_results = results.get('class_results', {})
    
    if class_accuracies:
        # Sort classes by accuracy (descending)
        sorted_classes = sorted(class_accuracies.items(), key=lambda x: x[1], reverse=True)
        
        print(f"{'Class':<8} {'Accuracy':<10} {'Correct':<8} {'Total':<8}")
        print("-" * 40)
        
        for class_name, accuracy in sorted_classes:
            class_data = class_results.get(class_name, {})
            correct = class_data.get('correct', 0)
            total = class_data.get('total', 0)
            print(f"{class_name:<8} {accuracy:>7.2f}%  {correct:>6} / {total:<6}")
    
    # Performance insights
#     print(f"\nðŸ’¡ INSIGHTS:")
#     if class_accuracies:
#         best_class = max(class_accuracies.items(), key=lambda x: x[1])
#         worst_class = min(class_accuracies.items(), key=lambda x: x[1])
        
#         print(f"   Best Performing Class: {best_class[0]} ({best_class[1]:.2f}%)")
#         print(f"   Worst Performing Class: {worst_class[0]} ({worst_class[1]:.2f}%)")
        
#         # Calculate performance distribution
#         high_perf = sum(1 for acc in class_accuracies.values() if acc >= 90)
#         medium_perf = sum(1 for acc in class_accuracies.values() if 70 <= acc < 90)
#         low_perf = sum(1 for acc in class_accuracies.values() if acc < 70)
        
#         print(f"   High Performance (â‰¥90%): {high_perf} classes")
#         print(f"   Medium Performance (70-89%): {medium_perf} classes")
#         print(f"   Low Performance (<70%): {low_perf} classes")
    
    print("\n" + "=" * 80)

In [55]:
display_evaluation_results(results)

EQUIPMENT PROFILE VECTOR SEARCH - EVALUATION RESULTS

ðŸ“Š OVERALL PERFORMANCE:
   Total Queries: 537
   Correct Predictions: 537
   Overall Accuracy: 100.00%
   Average Class Accuracy: 100.00%
   Total Classes: 21

ðŸŽ¯ PER-CLASS ACCURACY:
Class    Accuracy   Correct  Total   
----------------------------------------
AI1       100.00%      26 / 26    
AI10      100.00%      28 / 28    
AI11      100.00%      16 / 16    
AI12      100.00%      19 / 19    
AI13      100.00%      21 / 21    
AI14      100.00%      20 / 20    
AI15      100.00%      33 / 33    
AI16      100.00%      22 / 22    
AI17      100.00%      21 / 21    
AI18      100.00%      26 / 26    
AI19      100.00%      40 / 40    
AI2       100.00%      42 / 42    
AI21      100.00%      20 / 20    
AI22      100.00%      18 / 18    
AI3       100.00%      26 / 26    
AI4       100.00%      33 / 33    
AI5       100.00%      27 / 27    
AI6       100.00%      21 / 21    
AI7       100.00%      23 / 23    
AI8       100.0

In [56]:
results['class_results']['AI1']

{'total': 26,
 'correct': 26,
 'predictions': [{'query_image': 'dataset/equipment_train/AI1/20251112_112315(4).jpg',
   'true_class': 'AI1',
   'predicted_classes': ['AI1', 'AI1', 'AI1', 'AI1', 'AI1'],
   'top1_correct': True,
   'topk_correct': True,
   'neighbors': [{'id': 'AI1_09ca1946-dd9f-4373-9892-1a45dd09a3e8',
     'predicted_class': 'AI1',
     'distance': 0.9917260408401489,
     'original_path': 'dataset/equipment_train/AI1/20251112_112314.jpg'},
    {'id': 'AI1_f40fe8f3-0cec-4a74-90a2-fa3bce4162c2',
     'predicted_class': 'AI1',
     'distance': 0.9877064228057861,
     'original_path': 'dataset/equipment_train/AI1/20251112_112315.jpg'},
    {'id': 'AI1_2d420048-a5e0-4f6d-b262-ce4c480ae776',
     'predicted_class': 'AI1',
     'distance': 0.9814499616622925,
     'original_path': 'dataset/equipment_train/AI1/20251112_112315(0).jpg'},
    {'id': 'AI1_cda0eb04-066d-47a6-b66e-59e8100cf5d7',
     'predicted_class': 'AI1',
     'distance': 0.9752538204193115,
     'original_pat