In [1]:
import torch
import torch.onnx
import torchvision.models as models
import torchvision.transforms as transforms
from PIL import Image
import numpy as np
from tests.config import TEST_MISC_DIR

# Build the feature extractor: take the pretrained ResNet-50, keep all of its
# top-level children except the last one (the final fully connected layer),
# and chain them back together so the model outputs features instead of the
# classifier's scores.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases
# in favour of `weights=models.ResNet50_Weights.IMAGENET1K_V1`; kept as-is so
# the script still runs on older torchvision versions.
_backbone_layers = list(models.resnet50(pretrained=True).children())
resnet = torch.nn.Sequential(*_backbone_layers[:-1])
del _backbone_layers  # only `resnet` should remain in the notebook namespace
# Put the module in inference mode before it is traced for ONNX export.
resnet.eval()

# Preprocessing applied to every PIL image before it is fed to the network:
# resize the shorter side to 256 px, take the central 224x224 crop, convert to
# a float tensor, then normalise each of the three channels.
# Per-channel normalisation constants; presumably the ImageNet statistics the
# pretrained weights expect.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]
preprocess = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD),
    ]
)

# Load and preprocess the image
def preprocess_image(image_path):
    """Load one image file and turn it into a single-image input batch.

    The image is converted to RGB first: ``preprocess`` ends with a
    3-channel ``Normalize``, so grayscale ("L"), palette ("P") or RGBA
    images would otherwise produce a tensor with the wrong number of
    channels and fail during normalisation.

    Args:
        image_path: Path to the image file (str or path-like).

    Returns:
        ``preprocess(image)`` with a leading batch dimension of size 1 added.
    """
    # Use a context manager so the underlying file handle is closed promptly.
    # convert() returns a new, fully loaded image, so it remains valid after
    # the file is closed.
    with Image.open(image_path) as img:
        rgb_image = img.convert("RGB")
    input_tensor = preprocess(rgb_image)
    return input_tensor.unsqueeze(0)  # Add batch dimension

# Example input for exporting.
# Tracing only needs a tensor with the right shape and dtype, so a dummy batch
# shaped like the output of `preprocess_image` (1 image, 3 channels, 224x224
# after CenterCrop) is used. This avoids depending on a sample image at a
# path relative to the current working directory.
input_image = torch.zeros(1, 3, 224, 224)

# Export the model to ONNX. `dynamic_axes` marks dim 0 of both the input and
# the output as the symbolic 'batch_size', so the exported graph accepts any
# batch size even though it was traced with a batch of 1.
torch.onnx.export(
    resnet,
    input_image,
    "model.onnx",
    export_params=True,  # store the trained weights inside the .onnx file
    opset_version=9,
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
)

# Load ONNX model
import onnx
import onnxruntime as ort

# Validate the exported graph up front so a malformed export fails here with
# a clear checker error rather than later inside inference.
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)

# Pass the execution provider explicitly: onnxruntime builds that ship more
# than one provider (e.g. the GPU package) refuse to create a session without
# it. Change to e.g. ["CUDAExecutionProvider", "CPUExecutionProvider"] to run
# on a GPU.
ort_session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Run inference and extract feature vectors
def extract_feature_vectors(image_paths):
    """Run the exported ONNX model on a batch of images.

    Each path is loaded and preprocessed with ``preprocess_image``. The
    resulting single-image batches are concatenated along dim 0 and fed to
    ``ort_session`` in a single ``run`` call.

    Args:
        image_paths: Iterable of image paths (str or path-like) accepted by
            ``preprocess_image``.

    Returns:
        The first output returned by ``ort_session.run``, with one entry
        along axis 0 per input path, in input order.

    Raises:
        ValueError: If ``image_paths`` yields no paths.
    """
    input_images = [preprocess_image(image_path) for image_path in image_paths]
    # Check the materialised list (not the argument) so empty generators are
    # caught too; torch.cat on an empty list would fail with a cryptic error.
    if not input_images:
        raise ValueError("extract_feature_vectors() requires at least one image path")
    input_batch = torch.cat(input_images, dim=0)  # Combine images into a single batch
    input_name = ort_session.get_inputs()[0].name
    ort_outs = ort_session.run(None, {input_name: input_batch.numpy()})
    return ort_outs[0]

# Example usage: extract features for two sample images from the test data
# directory. Both entries are passed the same way (as path objects built with
# `/`); PIL's Image.open accepts path-like objects, so no str() is needed.
images = [TEST_MISC_DIR / "image.jpeg", TEST_MISC_DIR / "small_image.jpeg"]  # Replace with your image paths
feature_vectors = extract_feature_vectors(images)
# NOTE(review): `resnet` drops only the final layer and is not flattened, so
# the output may keep trailing spatial dims of size 1 — confirm whether a 2-D
# (batch, features) array is expected downstream.
print("Feature vector shape:", feature_vectors.shape)


ModuleNotFoundError: No module named 'torch'