In [None]:
# Install the packages used below: the Milvus client (pymilvus), PyTorch and
# torchvision for the embedding model, gdown for the Google Drive download and
# tqdm for progress bars. NOTE(review): versions are unpinned — consider pinning.
%pip install pymilvus torch gdown torchvision tqdm

In [None]:
# Install matplotlib, used to plot the search results.
%pip install matplotlib

In [None]:
import os
import zipfile

import gdown

# Download the dataset archive from Google Drive and extract it into ./paintings.
# NOTE(review): the cell that collects image paths reads '../clothing-images',
# not './paintings' — confirm which dataset this notebook is meant to index.
url = 'https://drive.google.com/uc?id=1OYDHLEy992qu5C4C8HV5uDIkOWRTAR1_'
output = './paintings.zip'

# Only download when the archive is missing, so re-running the notebook does
# not fetch the dataset again.
if not os.path.exists(output):
    gdown.download(url, output)

# Open the archive through `output` rather than repeating the path literal, so
# the download location and the extracted file cannot drift apart.
with zipfile.ZipFile(output, "r") as zip_ref:
    zip_ref.extractall("./paintings")  # Extract the dataset

In [28]:
# Configuration: Milvus collection, server address, embedding size and search settings.
COLLECTION_NAME = 'image_search_test'  # Milvus collection that is (re)created below

# Length of every stored embedding; must equal the length of the vectors the
# embedding model produces, because it is used as the FLOAT_VECTOR dimension.
DIMENSION = 2048

# Address of the Milvus server.
MILVUS_HOST = "localhost"
MILVUS_PORT = "19530"

BATCH_SIZE = 128  # images embedded and inserted together
TOP_K = 6  # nearest neighbours returned for each query image

In [29]:
# Connect to the Milvus server identified by MILVUS_HOST and MILVUS_PORT.
from pymilvus import connections

milvus_address = dict(host=MILVUS_HOST, port=MILVUS_PORT)
connections.connect(**milvus_address)

In [30]:
# Start from an empty collection: if one named COLLECTION_NAME already exists,
# drop it (this discards the data stored in it).
from pymilvus import utility

already_exists = utility.has_collection(COLLECTION_NAME)
if already_exists:
    print(f"Dropping collection: {COLLECTION_NAME}")
    utility.drop_collection(COLLECTION_NAME)

In [31]:
# Create the collection. Each row holds an auto-generated INT64 primary key,
# the image's file path and its embedding vector.
from pymilvus import FieldSchema, CollectionSchema, DataType, Collection

id_field = FieldSchema(name='id', dtype=DataType.INT64, is_primary=True, auto_id=True)
# VARCHAR fields need an explicit maximum length; file paths must fit in 200 characters.
path_field = FieldSchema(name='filepath', dtype=DataType.VARCHAR, max_length=200)
embedding_field = FieldSchema(name='image_embedding', dtype=DataType.FLOAT_VECTOR, dim=DIMENSION)

fields = [id_field, path_field, embedding_field]
schema = CollectionSchema(fields=fields)
collection = Collection(name=COLLECTION_NAME, schema=schema)

In [32]:
# Build an IVF_FLAT index using L2 distance on the embedding field, then load
# the collection into memory so it can be searched.
# 'nlist' is the number of IVF clusters; the search cells probe 128 of them
# ('nprobe'). NOTE(review): consider tuning nlist to the dataset size.
index_params = dict(
    metric_type='L2',
    index_type="IVF_FLAT",
    params=dict(nlist=16384),
)
collection.create_index(field_name="image_embedding", index_params=index_params)
collection.load()

# 3. Inserting Data

In [None]:
# Collect the image files to embed and insert.
import glob

# glob returns files in a filesystem-dependent order; sorting makes the
# insertion order reproducible across runs and machines.
paths = sorted(glob.glob('../clothing-images/**/*.jpg', recursive=True))
len(paths)

In [None]:
# Load the embedding model: an ImageNet-pretrained ResNet-50 with its last child
# module removed (in torchvision's ResNet that is the final `fc` classification
# layer), so the network outputs pooled image features instead of class scores.
import torch

model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet50', pretrained=True)
model = torch.nn.Sequential(*(list(model.children())[:-1]))
# eval() puts layers such as batch norm into inference mode. The trailing
# semicolon stops Jupyter from printing the entire model architecture.
model.eval();

In [36]:
# Image preprocessing applied before embedding: resize, center-crop to 224x224,
# convert to a tensor and normalize with the standard ImageNet channel
# statistics used by torchvision's pretrained models.
from torchvision import transforms

IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

resize_step = transforms.Resize(256)        # shorter side -> 256 px
crop_step = transforms.CenterCrop(224)      # central 224x224 patch
to_tensor_step = transforms.ToTensor()      # PIL image -> float tensor
normalize_step = transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)

preprocess = transforms.Compose([resize_step, crop_step, to_tensor_step, normalize_step])

In [None]:
# Inserting the data: embed the images in batches of BATCH_SIZE and insert
# (filepath, embedding) rows into the collection.
from PIL import Image
from tqdm import tqdm

def embed(data):
    """Embed one batch of images and insert it into `collection`.

    Args:
        data: two-element list ``[tensors, paths]``; ``tensors`` holds the
            preprocessed images and ``paths`` the matching file paths, in the
            same order.
    """
    with torch.no_grad():
        output = model(torch.stack(data[0]))
        # Flatten everything after the batch dimension instead of squeeze():
        # squeeze() also drops the batch dimension when the batch holds a single
        # image, which turned the last batch into one flat list and broke insert().
        vectors = torch.flatten(output, start_dim=1)
    collection.insert([data[1], vectors.tolist()])

data_batch = [[], []]

for path in tqdm(paths):
    # The context manager releases the file handle; convert() returns a new image.
    with Image.open(path) as img:
        im = img.convert('RGB')
    data_batch[0].append(preprocess(im))
    data_batch[1].append(path)
    if len(data_batch[0]) >= BATCH_SIZE:
        embed(data_batch)
        data_batch = [[], []]

# Insert the final, partially filled batch (if any).
if data_batch[0]:
    embed(data_batch)

collection.flush()

# 4. Performing the search

In [None]:
# Query image for the first search. glob only returns paths that exist, so an
# empty list means the file is missing.
import glob

query_pattern = './image-search-test/042/0422106042.jpg'
search_paths = glob.glob(query_pattern, recursive=True)
len(search_paths)

In [16]:
# Embed the query image(s) and run a timed nearest-neighbour search.
import time

def embed(data):
    """Embed preprocessed image tensors with `model`.

    Args:
        data: non-empty list of preprocessed image tensors of equal shape.

    Returns:
        A list with one embedding (list of floats) per input tensor, in order.

    Raises:
        ValueError: if `data` is empty (torch.stack cannot stack an empty list).
    """
    if not data:
        raise ValueError("embed() needs at least one image tensor")
    with torch.no_grad():
        ret = model(torch.stack(data))
    # Flatten everything after the batch dimension. Unlike squeeze(), this keeps
    # the batch dimension for a single image, so the result is always a list of
    # vectors.
    return torch.flatten(ret, start_dim=1).tolist()

if not search_paths:
    raise FileNotFoundError("No query image found; check the path assigned to search_paths.")

data_batch = [[], []]

for path in search_paths:
    with Image.open(path) as img:
        data_batch[0].append(preprocess(img.convert('RGB')))
    data_batch[1].append(path)

embeds = embed(data_batch[0])
# perf_counter is a monotonic, high-resolution clock meant for timing intervals.
start = time.perf_counter()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.perf_counter()

# Show, for each query, the matched file paths with their distances.
[[(hit.entity.get('filepath'), hit.distance) for hit in hits] for hits in res]


In [None]:
# Search for a single query image and plot it next to its TOP_K nearest matches.
import glob
import time
from matplotlib import pyplot as plt

# Define the search paths and the image to search for
query_pattern = './image-search-test/049/0490113004.jpg'
search_paths = glob.glob(query_pattern, recursive=True)
# torch.stack() fails with an opaque error on an empty list, so report a
# missing query image explicitly.
if not search_paths:
    raise FileNotFoundError(f"No query image matches {query_pattern!r}")

def embed(data):
    """Embed preprocessed image tensors with `model`.

    Args:
        data: non-empty list of preprocessed image tensors of equal shape.

    Returns:
        A list with one embedding (list of floats) per input tensor, in order.

    Raises:
        ValueError: if `data` is empty (torch.stack cannot stack an empty list).
    """
    if not data:
        raise ValueError("embed() needs at least one image tensor")
    with torch.no_grad():
        ret = model(torch.stack(data))
    # Flatten everything after the batch dimension. Unlike squeeze(), this keeps
    # the batch dimension for a single image, so the result is always a list of
    # vectors.
    return torch.flatten(ret, start_dim=1).tolist()

data_batch = [[], []]

for path in search_paths:
    with Image.open(path) as img:
        data_batch[0].append(preprocess(img.convert('RGB')))
    data_batch[1].append(path)

embeds = embed(data_batch[0])
start = time.perf_counter()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.perf_counter()

# One row per query: the query image in column 0, its matches in columns 1..TOP_K.
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)
# Hide every axis up front so columns left unused (fewer than TOP_K hits) stay blank.
for ax in axarr.flat:
    ax.set_axis_off()

for hits_i, hits in enumerate(res):
    # convert() loads the pixels into a new image, so the file can be closed at once.
    with Image.open(data_batch[1][hits_i]) as query_img:
        axarr[hits_i][0].imshow(query_img.convert('RGB'))
    axarr[hits_i][0].set_title('Search Time: ' + str(finish - start))
    for hit_i, hit in enumerate(hits):
        with Image.open(hit.entity.get('filepath')) as hit_img:
            axarr[hits_i][hit_i + 1].imshow(hit_img.convert('RGB'))
        axarr[hits_i][hit_i + 1].set_title('Distance: ' + str(hit.distance))

plt.savefig('search_result.png')

In [None]:
# Search for a single query image, plot it next to its TOP_K nearest matches
# (titled with file names and distances) and save the figure.
import glob
import os
import time
from matplotlib import pyplot as plt

# Define the search paths and the image to search for
query_pattern = '../clothing-images/088/0880060002.jpg'
search_paths = glob.glob(query_pattern, recursive=True)
# torch.stack() fails with an opaque error on an empty list, so report a
# missing query image explicitly.
if not search_paths:
    raise FileNotFoundError(f"No query image matches {query_pattern!r}")

def embed(data):
    """Embed preprocessed image tensors with `model`.

    Args:
        data: non-empty list of preprocessed image tensors of equal shape.

    Returns:
        A list with one embedding (list of floats) per input tensor, in order.

    Raises:
        ValueError: if `data` is empty (torch.stack cannot stack an empty list).
    """
    if not data:
        raise ValueError("embed() needs at least one image tensor")
    with torch.no_grad():
        ret = model(torch.stack(data))
    # Flatten everything after the batch dimension. Unlike squeeze(), this keeps
    # the batch dimension for a single image, so the result is always a list of
    # vectors.
    return torch.flatten(ret, start_dim=1).tolist()

data_batch = [[], []]

for path in search_paths:
    with Image.open(path) as img:
        data_batch[0].append(preprocess(img.convert('RGB')))
    data_batch[1].append(path)

embeds = embed(data_batch[0])
start = time.perf_counter()
res = collection.search(embeds, anns_field='image_embedding', param={'nprobe': 128}, limit=TOP_K, output_fields=['filepath'])
finish = time.perf_counter()

filename = 'search_result.png'
# One row per query: the query image in column 0, its matches in columns 1..TOP_K.
f, axarr = plt.subplots(len(data_batch[1]), TOP_K + 1, figsize=(20, 10), squeeze=False)
f.suptitle(f'Search time: {finish - start:.4f} s')
# Hide every axis up front so columns left unused (fewer than TOP_K hits) stay blank.
for ax in axarr.flat:
    ax.set_axis_off()

for hits_i, hits in enumerate(res):
    # Show the query image; convert() loads the pixels so the file can be closed at once.
    query_path = data_batch[1][hits_i]
    with Image.open(query_path) as query_img:
        axarr[hits_i][0].imshow(query_img.convert('RGB'))
    axarr[hits_i][0].set_title(f'Query: {os.path.basename(query_path)}')

    # Show the found images and their distances
    for hit_i, hit in enumerate(hits):
        found_image_path = hit.entity.get('filepath')
        with Image.open(found_image_path) as found_img:
            axarr[hits_i][hit_i + 1].imshow(found_img.convert('RGB'))
        found_image_filename = os.path.basename(found_image_path)
        axarr[hits_i][hit_i + 1].set_title(f'{found_image_filename}\nDistance: {hit.distance}')

# Save the result figure
plt.savefig(filename)
