# [Reverse Image Search on Milvus](https://milvus.io/bootcamp)
##  Code ref: https://github.com/towhee-io/examples/tree/main/image/reverse_image_search

In [None]:
# Upgrade pip itself. The --trusted-host flags tell pip to accept the listed
# PyPI hosts even when they do not present valid HTTPS.
!pip install --trusted-host pypi.python.org --trusted-host pypi.org --trusted-host files.pythonhosted.org --upgrade pip

In [None]:
# Install the notebook's third-party dependencies.
# python-dotenv and matplotlib are imported further down (dotenv / matplotlib.pyplot)
# and are therefore installed here as well.
!pip install  --trusted-host pypi.python.org --trusted-host pypi.org --trusted-host files.pythonhosted.org \
msal \
retry \
simplejson \
openai \
grpcio==1.58.0 \
pymilvus==2.3.5 \
protobuf \
grpcio-tools==1.58.0 \
pymongo \
tiktoken \
scikit-learn \
towhee \
opencv-python \
pillow \
python-dotenv \
matplotlib

In [None]:
import os
import openai
import requests
import simplejson as json
from retry import retry
from msal import PublicClientApplication, ConfidentialClientApplication, ClientApplication
import time

In [None]:
@retry(tries=3, delay=2)
def getapikey():
    """Fetch the OpenAI API key from the Azure-hosted key endpoint.

    Signs in to Azure AD with the service account (username/password flow)
    and POSTs the resulting bearer token to ``AZURE_ENDPOINT``.

    With this scheme the user does not handle the API key directly; per the
    original author's note it is managed (and rotated hourly) on the Azure
    back end. The settings are read from environment variables, which can be
    supplied e.g. as Kubernetes secrets:
    ``AZURE_CLIENT_ID``, ``AZURE_TENANT_ID``, ``AZURE_ENDPOINT``,
    ``AZURE_APPLICATION_SCOPE``, ``SVC_ACCOUNT`` and ``SVC_PASSWORD``.

    Returns:
        The decoded JSON body returned by the key endpoint.

    Raises:
        RuntimeError: if a required environment variable is missing or the
            Azure AD sign-in returns an error.
        requests.HTTPError: if the key endpoint answers with an error status.

    Failures are raised (instead of returning a sentinel value) so that the
    ``@retry`` decorator can actually retry them and so that a failed lookup
    can never be mistaken for an API key by the caller.
    """
    required = (
        "AZURE_CLIENT_ID",
        "AZURE_TENANT_ID",
        "AZURE_ENDPOINT",
        "AZURE_APPLICATION_SCOPE",
        "SVC_ACCOUNT",
        "SVC_PASSWORD",
    )
    missing = [name for name in required if not os.getenv(name)]
    if missing:
        raise RuntimeError("Missing environment variables: " + ", ".join(missing))

    client_id = os.getenv("AZURE_CLIENT_ID")
    tenant_id = os.getenv("AZURE_TENANT_ID")
    endpoint = os.getenv("AZURE_ENDPOINT")
    scopes = [os.getenv("AZURE_APPLICATION_SCOPE")]

    app = ClientApplication(
        client_id=client_id,
        authority="https://login.microsoftonline.com/" + tenant_id
    )

    acquire_tokens_result = app.acquire_token_by_username_password(
        username=os.getenv("SVC_ACCOUNT"),
        password=os.getenv("SVC_PASSWORD"),
        scopes=scopes,
    )
    if 'error' in acquire_tokens_result:
        # 'error_description' is read defensively so a missing field cannot
        # mask the real failure with a KeyError.
        raise RuntimeError(
            "Error: {} Description: {}".format(
                acquire_tokens_result['error'],
                acquire_tokens_result.get('error_description', ''),
            )
        )

    header_token = {"Authorization": "Bearer {}".format(acquire_tokens_result['access_token'])}
    # A timeout keeps the notebook from hanging forever on an unresponsive endpoint.
    rt = requests.post(url=endpoint, headers=header_token, data=b'{"key":"openaikey2"}', timeout=30)
    # Surface HTTP errors instead of returning an error payload as if it were the key.
    rt.raise_for_status()
    return rt.json()

In [None]:
# Configure the module-level OpenAI settings from the environment.
openai.api_type = os.getenv("OPENAI_API_TYPE")
openai.api_base = os.getenv("OPENAI_API_BASE")
openai.api_version = os.getenv("OPENAI_API_VERSION")
# The key is fetched dynamically from Azure. Access is based on the service account / AD group combination.
openai.api_key = getapikey()
# Do not print the key itself: notebook outputs are stored with the file and
# would leak the secret to anyone the notebook is shared with.
print("OpenAI API key acquired." if openai.api_key else "OpenAI API key is empty!")

In [None]:
import os
from openai import AzureOpenAI

# Build the Azure OpenAI client from the key stored on the openai module and
# the API version / endpoint taken from the environment.
_azure_api_version = os.getenv("OPENAI_API_VERSION")
_azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
ai_client = AzureOpenAI(
    api_key=openai.api_key,
    api_version=_azure_api_version,
    azure_endpoint=_azure_endpoint,
)


# Test image retrieval using multi-modal embeddings
## Reference docs: https://learn.microsoft.com/en-us/azure/ai-services/computer-vision/how-to/image-retrieval?tabs=csharp

In [None]:
import os
from dotenv import load_dotenv
from PIL import Image
import matplotlib.pyplot as plt
import glob
import json
import pathlib
from azurecv import (image_embedding, text_embedding, get_cosine_similarity, 
                     get_image_embedding_multiprocessing, 
                     search_by_image, search_by_text)

# image embedding
# Build the Computer Vision URL from the Azure endpoint. The base URL is
# normalised so it works with or without a trailing slash, and an unset
# variable fails loudly instead of producing a "None..." URL.
_azure_base_url = os.getenv('AZURE_OPENAI_ENDPOINT')
if not _azure_base_url:
    raise RuntimeError("AZURE_OPENAI_ENDPOINT is not set")
endpoint = f"{_azure_base_url.rstrip('/')}/computervision/"
image = "data/images/train/banana/n07753592_18.JPEG"
image_emb = image_embedding(image, endpoint, openai.api_key)
print(image_emb)

In [None]:
import csv
from glob import glob
from pathlib import Path
from statistics import mean

from towhee import pipe, ops, DataCollection
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility


# --- Towhee embedding settings ---
MODEL = "resnet50"
DEVICE = None  # None selects the default device (cuda is enabled if available)

# --- Milvus connection, collection and index settings ---
HOST = "standalone"
PORT = "19530"
TOPK = 10
DIM = 2048  # dimension of the embedding extracted by MODEL
COLLECTION_NAME = "reverse_image_search"
INDEX_TYPE = "IVF_FLAT"
METRIC_TYPE = "L2"

# --- Data sources ---
# Each is either a path to a csv (column_1 indicates the image path)
# or a glob pattern of image paths.
INSERT_SRC = "reverse_image_search.csv"
QUERY_SRC = "./test/*/*.JPEG"

In [None]:
# Load image path
def load_image(x):
    """Yield image paths from a CSV file or from a glob pattern.

    Args:
        x: Either a path ending in ``csv`` or a glob pattern of image paths.
           For a CSV, the first row is treated as a header and skipped, and
           the second column (index 1) of every following row is yielded.

    Yields:
        One image path (str) per CSV data row or per glob match.

    Raises:
        IndexError: if a non-blank CSV row has fewer than two columns.
    """
    if x.endswith('csv'):
        # newline='' is what the csv module asks for, so it can handle line
        # endings and newlines embedded in quoted fields itself.
        with open(x, newline='') as f:
            reader = csv.reader(f)
            # Skip the header. The default keeps an empty file from raising
            # StopIteration inside the generator (which becomes RuntimeError).
            next(reader, None)
            for item in reader:
                # csv.reader yields an empty list for blank lines; skip them.
                if item:
                    yield item[1]
    else:
        yield from glob(x)
            
# Embedding pipeline: expand the source into image paths, decode each image,
# then embed it with the timm model selected by MODEL / DEVICE.
_decode_op = ops.image_decode()
_embed_op = ops.image_embedding.timm(model_name=MODEL, device=DEVICE)
p_embed = (
    pipe.input('src')
    .flat_map('src', 'img_path', load_image)
    .map('img_path', 'img', _decode_op)
    .map('img', 'vec', _embed_op)
)

In [None]:
# Preview of the embedding output for a few sample images
# (display only, no need for implementation).
p_display = p_embed.output('img_path', 'img', 'vec')
_preview = p_display('./data/images/test/goldfish/*.JPEG')
DataCollection(_preview).show()

In [None]:
# Create milvus collection (delete first if exists)
def create_milvus_collection(collection_name, dim):
    """Create a fresh Milvus collection for image embeddings and index it.

    An existing collection with the same name is dropped first. The schema
    has a VARCHAR primary key ``path`` (max length 500) and a FLOAT_VECTOR
    field ``embedding`` of dimension ``dim``. The index on ``embedding`` uses
    the module-level ``METRIC_TYPE`` and ``INDEX_TYPE`` with ``nlist`` 2048.

    Args:
        collection_name: Name of the collection to (re)create.
        dim: Dimension of the embedding vectors.

    Returns:
        The newly created ``Collection``, so callers can insert into and
        search it (previously nothing was returned and callers got ``None``).
    """
    if utility.has_collection(collection_name):
        utility.drop_collection(collection_name)

    fields = [
        FieldSchema(name='path', dtype=DataType.VARCHAR, description='path to image', max_length=500,
                    is_primary=True, auto_id=False),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='image embedding vectors', dim=dim)
    ]
    schema = CollectionSchema(fields=fields, description='reverse image search')
    collection = Collection(name=collection_name, schema=schema)

    index_params = {
        'metric_type': METRIC_TYPE,
        'index_type': INDEX_TYPE,
        'params': {"nlist": 2048}
    }
    collection.create_index(field_name='embedding', index_params=index_params)
    return collection

In [None]:
# Open the connection to the Milvus server configured by HOST / PORT
connections.connect(
    host=HOST,
    port=PORT,
)

# Build the collection that will hold the image embeddings
collection = create_milvus_collection(COLLECTION_NAME, DIM)
_created_msg = f'A new collection created: {COLLECTION_NAME}'
print(_created_msg)