In [4]:
pip install torch torchvision torchaudio transformers sentencepiece gcsfs

Collecting torch
  Downloading torch-2.9.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (30 kB)
Collecting torchvision
  Downloading torchvision-0.24.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (5.9 kB)
Collecting torchaudio
  Downloading torchaudio-2.9.1-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (6.9 kB)
Collecting sympy>=1.13.3 (from torch)
  Downloading sympy-1.14.0-py3-none-any.whl.metadata (12 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.8.93 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.8.93-py3-none-manylinux2010_x86_64.manylinux_2_12_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cuda-runtime-cu12==12.8.90 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cuda-cupti-cu12==12.8.90 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.8.90-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (1.7 kB)
Collecting nvidia-cudnn-cu12==9.10.2.21 (from torch)


In [3]:
import os
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModel
from google.cloud import aiplatform
from google.cloud import aiplatform_v1

In [16]:
# Google Cloud project identifier and region. INDEX_ENDPOINT below is built
# from these two values so they cannot drift out of sync with it.
PROJECT_ID = "629242692180"
REGION = "asia-southeast1"

# Settings for the currently deployed index.
# API_ENDPOINT is the host passed to the Match service client; it is kept as a
# literal because its leading component is specific to this deployment.
API_ENDPOINT = "1148065613.asia-southeast1-629242692180.vdb.vertexai.goog"
INDEX_ENDPOINT = f"projects/{PROJECT_ID}/locations/{REGION}/indexEndpoints/43567048738996224"
DEPLOYED_INDEX_ID = "equipment_profile_1765274237629"

# Defaults for the single-image demo query further down.
QUERY_IMAGE_PATH = "datatest/AI2/20251112_111947(0).jpg"
NUM_NEIGHBORS = 5

In [12]:
# Route Match service calls to the host configured in API_ENDPOINT.
client_options = dict(api_endpoint=API_ENDPOINT)
vector_search_client = aiplatform_v1.MatchServiceClient(client_options=client_options)

In [13]:
# Checkpoint shared by the processor and the model so the two cannot diverge.
SIGLIP_MODEL_NAME = "google/siglip-base-patch16-224"

# Load the processor and model once; create_query_embedding reads both globals.
try:
    processor = AutoProcessor.from_pretrained(SIGLIP_MODEL_NAME)
    model = AutoModel.from_pretrained(SIGLIP_MODEL_NAME)
    model.eval()  # evaluation mode: this notebook only runs inference
    print("SigLIP model loaded successfully.")
except Exception as e:
    print(f"Error loading SigLIP model: {e}")
    # Re-raise rather than calling exit(): the cell then fails visibly with a
    # full traceback and "Run All" stops here instead of continuing without
    # `processor` / `model` defined.
    raise

SigLIP model loaded successfully.


In [14]:
def create_query_embedding(image_path: str) -> list:
    """Embed one image file with the module-level SigLIP processor and model.

    Args:
        image_path: Filesystem path of the image to embed.

    Returns:
        The L2-normalised image features converted to nested Python lists via
        ``tolist()`` after dropping the leading dimension with ``squeeze(0)``
        (presumably a flat list of floats for a single image -- confirm
        against the model's output shape).

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Query image not found at: {image_path}")

    # The context manager closes the underlying file even if decoding fails;
    # convert() returns an independent image that stays valid afterwards.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    # Pre-processing
    inputs = processor(images=image, return_tensors="pt")

    # Inference (no gradients are needed for a forward pass)
    with torch.no_grad():
        outputs = model.get_image_features(**inputs)
        # Scale to unit L2 norm along the last dimension.
        query_vector = outputs / outputs.norm(p=2, dim=-1, keepdim=True)

    return query_vector.squeeze(0).tolist()

In [17]:
def vector_search(query_image_path: str, index_endpoint: str, deploy_index_id: str, num_neighbors: int):
    """Embed an image and look up its nearest neighbors in the deployed index.

    The call is best-effort: failures are printed and ``None`` is returned
    instead of raising, so callers must check the result before using it.
    No restricts/filters are attached to the query.

    Args:
        query_image_path: Path of the image to embed with create_query_embedding.
        index_endpoint: Index endpoint resource name sent in the request.
        deploy_index_id: ID of the deployed index to query.
        num_neighbors: Number of nearest neighbors to request.

    Returns:
        The response of ``vector_search_client.find_neighbors`` on success,
        otherwise ``None``.
    """
    # Stage 1: build the query vector. Kept in its own try block so that only
    # genuine embedding failures are reported as such.
    try:
        query_embedding_vector = create_query_embedding(query_image_path)
    except Exception as e:
        print(f"Error creating embedding: {e}")
        return None

    # Stage 2: build and send the request.
    try:
        datapoint = aiplatform_v1.IndexDatapoint(
            feature_vector=query_embedding_vector
        )
        query = aiplatform_v1.FindNeighborsRequest.Query(
            datapoint=datapoint,
            # The number of nearest neighbors to be retrieved
            neighbor_count=num_neighbors,
        )
        request = aiplatform_v1.FindNeighborsRequest(
            index_endpoint=index_endpoint,
            deployed_index_id=deploy_index_id,
            # Request can have multiple queries
            queries=[query],
            return_full_datapoint=False,
        )
        return vector_search_client.find_neighbors(request)
    except Exception as e:
        print(f"ERROR: {e}")
        return None

In [19]:
# Run the demo query with the configured defaults and show the raw response.
res = vector_search(
    query_image_path=QUERY_IMAGE_PATH,
    index_endpoint=INDEX_ENDPOINT,
    deploy_index_id=DEPLOYED_INDEX_ID,
    num_neighbors=NUM_NEIGHBORS,
)
res

nearest_neighbors {
  neighbors {
    datapoint {
      datapoint_id: "equipment_train_AI2_20251112_111948.jpg"
      crowding_tag {
        crowding_attribute: "0"
      }
    }
    distance: 0.96325886249542236
  }
  neighbors {
    datapoint {
      datapoint_id: "equipment_train_AI2_20251112_111948(0).jpg"
      crowding_tag {
        crowding_attribute: "0"
      }
    }
    distance: 0.9466739296913147
  }
  neighbors {
    datapoint {
      datapoint_id: "equipment_train_AI2_20251112_111943(0).jpg"
      crowding_tag {
        crowding_attribute: "0"
      }
    }
    distance: 0.94594419002532959
  }
  neighbors {
    datapoint {
      datapoint_id: "equipment_train_AI2_20251112_111947.jpg"
      crowding_tag {
        crowding_attribute: "0"
      }
    }
    distance: 0.93916523456573486
  }
  neighbors {
    datapoint {
      datapoint_id: "equipment_train_AI2_20251112_111949.jpg"
      crowding_tag {
        crowding_attribute: "0"
      }
    }
    distance: 0.938072204589

In [20]:
# Drill into the response: take the result set at index 0 (a single query was
# sent), then the first entry of its neighbor list, and print that entry.
# NOTE: vector_search returns None when it fails, in which case this cell raises.
first_query_result = res.nearest_neighbors[0]
neighbors = first_query_result.neighbors
neighbor = neighbors[0]
print(neighbor) 

datapoint {
  datapoint_id: "equipment_train_AI2_20251112_111948.jpg"
  crowding_tag {
    crowding_attribute: "0"
  }
}
distance: 0.96325886249542236



In [21]:
from google.cloud import storage
import json

# Download the JSON-lines metadata file (one JSON object per line).
client = storage.Client()
bucket = client.bucket("axmt_equipment_profile")
blob = bucket.blob("siglip_vectors/local_image_vectors.json")

content = blob.download_as_text()

# Lookup tables keyed by datapoint id; the evaluation function reads both as
# module-level globals.
original_path = {}
label_class = {}

for line in content.splitlines():
    if not line.strip():
        continue  # tolerate blank lines, as the evaluation loader does
    record = json.loads(line)
    datapoint_id = record['id']  # named so the builtin id() is not shadowed
    original_path[datapoint_id] = record['original_path']
    label_class[datapoint_id] = record['label_class']

In [None]:
def _load_test_embeddings_from_gcs(test_data_path):
    """Collect test-set embedding records from every ``.json`` blob under a GCS path.

    Each blob is read as JSON lines. A record is kept when its ``original_path``
    contains the substring ``test`` (case-insensitive) and its ``embedding`` is
    non-empty. Unparseable lines are reported and skipped; any other error
    abandons the rest of that blob and moves on to the next one.

    Returns:
        list[dict]: Records with keys ``path``, ``class``, ``id`` and ``embedding``.
    """
    import json
    from google.cloud import storage

    print(f"Loading test embeddings from GCS path: {test_data_path}")

    # Split "gs://bucket/prefix" into the bucket name and a (possibly empty) prefix.
    gcs_path = test_data_path.replace('gs://', '')
    bucket_name, _, blob_prefix = gcs_path.partition('/')

    storage_client = storage.Client()
    bucket = storage_client.bucket(bucket_name)

    test_embeddings = []
    for blob in bucket.list_blobs(prefix=blob_prefix):
        if not blob.name.endswith('.json'):
            continue
        print(f"Processing GCS file: {blob.name}")

        try:
            content = blob.download_as_text()

            for line in content.splitlines():
                if not line.strip():
                    continue
                try:
                    embedding_data = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f"Error parsing JSON line: {e}")
                    continue

                source_path = embedding_data.get('original_path', '')
                embedding_vector = embedding_data.get('embedding', [])

                # Filter for test images (assuming they contain 'test' in path)
                if 'test' in source_path.lower() and embedding_vector:
                    test_embeddings.append({
                        'path': source_path,
                        'class': embedding_data.get('label_class', ''),
                        'id': embedding_data.get('id', ''),
                        'embedding': embedding_vector
                    })

        except Exception as e:
            print(f"Error processing GCS blob {blob.name}: {e}")
            continue

    return test_embeddings


def evaluate_equipment_test_dataset_from_embeddings(test_data_path="gs://axmt_equipment_profile/siglip_vectors",
                                                     index_endpoint=INDEX_ENDPOINT, 
                                                     deployed_index_id=DEPLOYED_INDEX_ID, 
                                                     num_neighbors=5,
                                                     max_images_per_class=None):
    """
    Evaluate the vector search system using pre-computed embeddings from JSON files in GCS.
    This function uses embeddings directly instead of processing raw images.

    Relies on three module-level objects defined in earlier cells:
    ``vector_search_client`` (sends the queries) and the ``label_class`` /
    ``original_path`` dicts (map a neighbor's datapoint id to its class and
    source path). Neighbors whose id is missing from ``label_class`` are
    ignored, so "top-1" means the first neighbor that has a known label.

    Args:
        test_data_path (str): Path to GCS bucket containing JSON files with embeddings
        index_endpoint (str): Vector search index endpoint
        deployed_index_id (str): Deployed index ID
        num_neighbors (int): Number of neighbors to retrieve
        max_images_per_class (int): Maximum number of images to test per class (None for all)

    Returns:
        dict: Evaluation results with keys ``total_queries``,
        ``correct_predictions``, ``class_results``, ``detailed_results`` and
        ``summary`` (overall / per-class / average accuracy in percent).

    Raises:
        ValueError: If no test embeddings are found under ``test_data_path``.
    """
    from collections import defaultdict
    import time
    from google.cloud import aiplatform_v1

    # Initialize results tracking
    results = {
        'total_queries': 0,
        'correct_predictions': 0,
        'class_results': defaultdict(lambda: {'total': 0, 'correct': 0, 'predictions': []}),
        'detailed_results': []
    }

    test_embeddings = _load_test_embeddings_from_gcs(test_data_path)

    if not test_embeddings:
        raise ValueError("No test embeddings found. Please check your GCS path contains test data.")

    print(f"Found {len(test_embeddings)} test embeddings from GCS")

    # Apply max_images_per_class limit if specified (keeps the first N per class)
    if max_images_per_class:
        embeddings_by_class = defaultdict(list)
        for emb_data in test_embeddings:
            embeddings_by_class[emb_data['class']].append(emb_data)

        test_embeddings = [
            emb_data
            for class_embeddings in embeddings_by_class.values()
            for emb_data in class_embeddings[:max_images_per_class]
        ]

    # Get unique classes
    equipment_classes = list({emb['class'] for emb in test_embeddings})

    print(f"Found {len(equipment_classes)} equipment classes: {sorted(equipment_classes)}")
    print(f"Total test embeddings: {len(test_embeddings)}")
    print(f"Starting evaluation with {num_neighbors} neighbors per query...\n")

    # Process each test embedding
    for emb_data in test_embeddings:
        # Deliberately NOT named `original_path`: assigning that name here would
        # make it a local string and hide the module-level lookup dict that is
        # consulted below, turning every neighbor lookup into an AttributeError.
        query_path = emb_data['path']
        true_class = emb_data['class']

        try:
            # Create vector search request using pre-computed embedding
            request = aiplatform_v1.FindNeighborsRequest(
                index_endpoint=index_endpoint,
                deployed_index_id=deployed_index_id,
                queries=[
                    aiplatform_v1.FindNeighborsRequest.Query(
                        datapoint=aiplatform_v1.IndexDatapoint(
                            feature_vector=emb_data['embedding']
                        ),
                        neighbor_count=num_neighbors,
                    )
                ],
                return_full_datapoint=False,
            )

            # Execute the request
            response = vector_search_client.find_neighbors(request)

            if response and response.nearest_neighbors:
                neighbors = response.nearest_neighbors[0].neighbors

                # Extract predicted classes from neighbors with a known label
                predicted_classes = []
                neighbor_details = []

                for neighbor in neighbors:
                    neighbor_id = neighbor.datapoint.datapoint_id
                    if neighbor_id not in label_class:
                        continue

                    pred_class = label_class[neighbor_id]
                    predicted_classes.append(pred_class)
                    neighbor_details.append({
                        'id': neighbor_id,
                        'predicted_class': pred_class,
                        'distance': neighbor.distance,
                        'original_path': original_path.get(neighbor_id, 'Unknown')
                    })

                # Determine if prediction is correct (top-1 and top-k accuracy)
                top1_correct = bool(predicted_classes) and predicted_classes[0] == true_class
                topk_correct = true_class in predicted_classes[:num_neighbors]

                # Update results
                class_stats = results['class_results'][true_class]
                results['total_queries'] += 1
                class_stats['total'] += 1

                if top1_correct:
                    results['correct_predictions'] += 1
                    class_stats['correct'] += 1

                # Store detailed results
                result_detail = {
                    'query_image': query_path,
                    'true_class': true_class,
                    'predicted_classes': predicted_classes,
                    'top1_correct': top1_correct,
                    'topk_correct': topk_correct,
                    'neighbors': neighbor_details
                }
                results['detailed_results'].append(result_detail)
                class_stats['predictions'].append(result_detail)

                # Print progress
                if results['total_queries'] % 10 == 0:
                    current_accuracy = (results['correct_predictions'] / results['total_queries']) * 100
                    print(f"Processed {results['total_queries']} embeddings, Current accuracy: {current_accuracy:.2f}%")

            else:
                print(f"No response for embedding: {query_path}")

            # Small delay to avoid overwhelming the API
            time.sleep(0.1)

        except Exception as e:
            print(f"Error processing embedding {query_path}: {e}")
            continue

    # Calculate final metrics
    overall_accuracy = (results['correct_predictions'] / results['total_queries']) * 100 if results['total_queries'] > 0 else 0

    # Calculate per-class accuracy
    class_accuracies = {}
    for class_name, class_data in results['class_results'].items():
        if class_data['total'] > 0:
            class_accuracies[class_name] = (class_data['correct'] / class_data['total']) * 100
        else:
            class_accuracies[class_name] = 0

    # Add summary to results
    results['summary'] = {
        'overall_accuracy': overall_accuracy,
        'class_accuracies': class_accuracies,
        'total_classes': len(equipment_classes),
        'avg_class_accuracy': sum(class_accuracies.values()) / len(class_accuracies) if class_accuracies else 0
    }

    return results

In [None]:
def display_evaluation_results(results):
    """
    Display comprehensive evaluation results in a formatted way.

    Prints an overall summary, a per-class accuracy table, the most frequent
    top-1 misclassifications (at most ten) and a few insights. Nothing is
    returned.

    Args:
        results (dict): Results from evaluate_equipment_test_dataset_from_embeddings.
            Reads ``total_queries``, ``correct_predictions``, ``class_results``,
            ``detailed_results`` and ``summary``.
    """
    # Imported locally: no module-level import in this notebook provides it.
    from collections import defaultdict

    # Status symbols written as escapes so the source stays encoding-safe.
    ok_mark = "\u2705"          # check mark
    warn_mark = "\u26A0\uFE0F"  # warning sign
    fail_mark = "\u274C"        # cross mark

    print("=" * 80)
    print("EQUIPMENT PROFILE VECTOR SEARCH EVALUATION RESULTS")
    print("=" * 80)

    # Overall summary
    summary = results['summary']
    print("\n\U0001F4CA OVERALL PERFORMANCE:")
    print(f"   Total Queries: {results['total_queries']}")
    print(f"   Correct Predictions: {results['correct_predictions']}")
    print(f"   Overall Accuracy: {summary['overall_accuracy']:.2f}%")
    print(f"   Average Class Accuracy: {summary['avg_class_accuracy']:.2f}%")
    print(f"   Total Classes: {summary['total_classes']}")

    # Per-class results
    print("\n\U0001F4CB PER-CLASS ACCURACY:")
    class_results = results['class_results']
    class_accuracies = summary['class_accuracies']

    # Sort by accuracy for better visualization
    sorted_classes = sorted(class_accuracies.items(), key=lambda x: x[1], reverse=True)

    print(f"{'Class':<8} {'Accuracy':<10} {'Correct/Total':<15} {'Status'}")
    print("-" * 50)

    for class_name, accuracy in sorted_classes:
        total = class_results[class_name]['total']
        correct = class_results[class_name]['correct']
        if accuracy >= 90:
            status = f"{ok_mark} Excellent"
        elif accuracy >= 70:
            status = f"{warn_mark} Good"
        else:
            status = f"{fail_mark} Needs Work"
        print(f"{class_name:<8} {accuracy:<10.2f}% {correct}/{total:<13} {status}")

    # Confusion analysis: count (true class -> top-1 predicted class) pairs
    print("\n\U0001F50D CONFUSION ANALYSIS:")
    confusion_data = defaultdict(lambda: defaultdict(int))

    for detail in results['detailed_results']:
        if detail['predicted_classes']:
            top_prediction = detail['predicted_classes'][0]
            confusion_data[detail['true_class']][top_prediction] += 1

    # Show most common misclassifications
    print("Most common misclassifications:")
    misclassifications = [
        (true_class, pred_class, count)
        for true_class, predictions in confusion_data.items()
        for pred_class, count in predictions.items()
        if true_class != pred_class and count > 0
    ]

    # Sort by frequency
    misclassifications.sort(key=lambda x: x[2], reverse=True)

    if misclassifications:
        print(f"{'True Class':<10} {'Predicted As':<12} {'Count'}")
        print("-" * 35)
        for true_cls, pred_cls, count in misclassifications[:10]:  # Top 10 misclassifications
            print(f"{true_cls:<10} {pred_cls:<12} {count}")
    else:
        print("No misclassifications found!")

    # Performance insights
    print("\n\U0001F4A1 INSIGHTS:")
    if not class_accuracies:
        # max()/min() below would raise on an empty mapping (e.g. no query succeeded).
        print("   No per-class results available.")
        return

    best_class = max(class_accuracies.items(), key=lambda x: x[1])
    worst_class = min(class_accuracies.items(), key=lambda x: x[1])

    print(f"   Best performing class: {best_class[0]} ({best_class[1]:.2f}%)")
    print(f"   Worst performing class: {worst_class[0]} ({worst_class[1]:.2f}%)")

    accuracy_range = best_class[1] - worst_class[1]
    print(f"   Performance variance: {accuracy_range:.2f}%")

    if accuracy_range > 30:
        print(f"   {warn_mark} High variance detected - some classes may need more training data")
    elif summary['overall_accuracy'] < 80:
        print(f"   {warn_mark} Overall accuracy below 80% - consider model fine-tuning")
    else:
        print(f"   {ok_mark} Good performance consistency across classes")

def analyze_failed_predictions(results, class_name=None, top_n=5):
    """
    Analyze failed predictions for debugging purposes.

    A prediction counts as failed when its ``top1_correct`` flag is false.
    Prints a count of the failures followed by details for the first
    ``top_n`` of them, in the order they appear in ``detailed_results``.
    Nothing is returned.

    Args:
        results (dict): Results from evaluate_equipment_test_dataset_from_embeddings;
            only ``detailed_results`` is read.
        class_name (str): Specific class to analyze (None for all classes)
        top_n (int): Number of failed cases to show
    """
    # Header symbol written as an escape so the source stays encoding-safe.
    print("\n\U0001F50D ANALYZING FAILED PREDICTIONS:")

    failed_predictions = [
        detail
        for detail in results['detailed_results']
        if not detail['top1_correct']
        and (class_name is None or detail['true_class'] == class_name)
    ]

    if not failed_predictions:
        print("No failed predictions found!")
        return

    print(f"Found {len(failed_predictions)} failed predictions")
    if class_name:
        print(f"Filtering by class: {class_name}")

    print(f"\nTop {min(top_n, len(failed_predictions))} failed cases:")
    print("-" * 80)

    for position, failure in enumerate(failed_predictions[:top_n], start=1):
        predicted = failure['predicted_classes']
        print(f"\n{position}. Query: {failure['query_image']}")
        print(f"   True class: {failure['true_class']}")
        print(f"   Predicted: {predicted[0] if predicted else 'No prediction'}")
        print(f"   Top predictions: {predicted[:3]}")

        if failure['neighbors']:
            # Details of the first recorded neighbor for this query.
            first_neighbor = failure['neighbors'][0]
            print(f"   Closest match distance: {first_neighbor['distance']:.4f}")
            print(f"   Closest match path: {first_neighbor['original_path']}")

In [None]:
# Quick evaluation run: query the deployed index with embeddings already stored
# in GCS, capped at three test images per class so the cell finishes quickly.
print("Running evaluation using pre-computed embeddings from GCS (max 3 images per class)...")
quick_results = evaluate_equipment_test_dataset_from_embeddings(
    "gs://axmt_equipment_profile/siglip_vectors",  # GCS path with embedding JSON files
    INDEX_ENDPOINT,
    DEPLOYED_INDEX_ID,
    num_neighbors=5,
    max_images_per_class=3,  # Limit for quick testing
)

# Print the formatted report for the quick run.
display_evaluation_results(quick_results)

In [None]:
# Full evaluation using pre-computed embeddings from GCS.
# WARNING: This may take a long time depending on the dataset size, so it is
# guarded by an explicit flag instead of being hidden inside a string literal.
RUN_FULL_EVALUATION = False  # set to True to run the complete evaluation

if RUN_FULL_EVALUATION:
    # Run full evaluation using pre-computed embeddings from GCS
    print("Running full evaluation using pre-computed embeddings from GCS...")
    full_results = evaluate_equipment_test_dataset_from_embeddings(
        test_data_path="gs://axmt_equipment_profile/siglip_vectors",  # GCS path with embedding JSON files
        index_endpoint=INDEX_ENDPOINT,
        deployed_index_id=DEPLOYED_INDEX_ID,
        num_neighbors=5,
        max_images_per_class=None  # No limit - use all test embeddings
    )

    # Display comprehensive results
    display_evaluation_results(full_results)

    # Analyze failed predictions for specific classes
    analyze_failed_predictions(full_results, class_name="AI1", top_n=5)
    analyze_failed_predictions(full_results, class_name="AI2", top_n=5)

    # Save results to file for later analysis
    import json
    # Convert defaultdict to regular dict for JSON serialization
    results_copy = dict(full_results)
    results_copy['class_results'] = dict(results_copy['class_results'])
    with open('evaluation_results_embeddings.json', 'w') as f:
        json.dump(results_copy, f, indent=2, default=str)

    print("Results saved to evaluation_results_embeddings.json")
else:
    print("Full evaluation is disabled. Set RUN_FULL_EVALUATION = True to run complete evaluation.")
    print("Now uses pre-computed embeddings from GCS for much faster evaluation!")