# Towhee image search

## Preparation

In [2]:
from pymilvus import connections, FieldSchema, CollectionSchema, DataType, Collection, utility

In [3]:
# create reverse_image_search collection and build index
def create_milvus_collection(collection_name, dim, host='127.0.0.1', port='19530'):
    """Define a Milvus collection for image embeddings and request an IVF_FLAT index on it.

    Args:
        collection_name: Name of the Milvus collection.
        dim: Dimensionality of the vectors stored in the ``embedding`` field.
        host: Milvus server host; defaults to the local instance used in this notebook.
        port: Milvus server port; defaults to the port used in this notebook.

    Returns:
        The ``Collection`` object on which ``create_index`` has been called.
    """
    connections.connect(host=host, port=port)
    fields = [
        # The keyword is `description`; it was previously misspelled
        # `descrition`, so the text was not passed as the field description.
        FieldSchema(name='id', dtype=DataType.INT64, description='ids', is_primary=True, auto_id=True),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='embedding vectors', dim=dim),
    ]
    schema = CollectionSchema(fields=fields, description='reverse image search')
    collection = Collection(name=collection_name, schema=schema)

    # create IVF_FLAT index for collection (L2 distance, nlist=2048).
    index_params = {
        'metric_type': 'L2',
        'index_type': "IVF_FLAT",
        'params': {"nlist": 2048}
    }
    collection.create_index(field_name="embedding", index_params=index_params)
    return collection

# Connect to Milvus and set up the 'reverse_image_search' collection;
# 2048 is the `dim` passed to the `embedding` vector field.
collection = create_milvus_collection('reverse_image_search', 2048)

In [4]:
# Connect again (create_milvus_collection already does this) and display the
# number of entities in the collection.
connections.connect(host='127.0.0.1', port='19530')
collection.num_entities
# Uncomment to drop the collection and start over:
# Collection('reverse_image_search').drop()

1000

In [11]:
# NOTE(review): this cell carries execution count 11 but sits above In[6], and
# 'reverse_image_search_bak' is only created in the "Advanced Test" section, so
# it presumably fails in a top-to-bottom run against a fresh Milvus instance.
connections.connect(host='127.0.0.1', port='19530')
# Entity counts of the other collections used in this notebook:
# Collection('resnet50').num_entities
# Collection('resnet101').num_entities
Collection('reverse_image_search_bak').num_entities
# Uncomment to drop the per-model collections:
# Collection('resnet50').drop()
# Collection('resnet101').drop()

2000

## Reverse image search with Towhee

In [6]:
import towhee
from diskcache import Cache

# Connector to the 'reverse_image_search' collection; passed as `ann_index`
# to the insert/search pipelines below.
milvus_collection = towhee.connectors.milvus(uri='tcp://127.0.0.1:19530/reverse_image_search')
# diskcache store rooted at ./tmp; `save_path_id` fills it with
# Milvus primary key -> image path entries.
cache = Cache('./tmp')

@towhee.register(name='save_path_id')
def save_path_id(path, mr):
    """Record which image path produced each inserted Milvus primary key.

    Args:
        path: Image path(s) of the inserted vectors: a sequence of paths for a
            batched insert, or a single path string for an unbatched insert.
        mr: Insert result exposing ``primary_keys``, index-aligned with ``path``.

    Returns:
        The number of entries in ``cache`` after the update.
    """
    # A bare string is ONE path, not a sequence of paths: indexing it would
    # store a single character per key. The unbatched '/insert' pipeline in
    # this notebook appears to pass exactly one path per call.
    paths = [path] if isinstance(path, str) else path
    for idx, key in enumerate(mr.primary_keys):
        cache[key] = paths[idx]
    return len(cache)

In [None]:
import time
# Time the ingest of every training image into Milvus.
time1 = time.time()

with milvus_collection:
    dc = (
        # One row per file matched by the glob, stored in column 'path'.
        towhee.glob['path']('./extracted_train/*/*.JPEG')
          .image_decode['path', 'img']()
          # label = name of the directory that contains the image
          .runas_op['path', 'label'](func = lambda path: path.split('/')[-2])
          .image_embedding.timm['img', 'vec'](model_name='resnet50')
          # presumably groups rows in chunks of 100 so inserts are batched -- TODO confirm
          .batch(100)
          .ann_insert['vec', 'res'](ann_index=milvus_collection)
          # remember primary key -> path in `cache`; 'num' is the cache size
          .save_path_id[('path', 'res'), 'num']()
          # `dc` is iterated again by the metric-report cell further down
          .unstream()
    )
print(time.time()-time1)

In [None]:
import cv2
from towhee._types.image import Image

# Query Milvus with the test images matching 'n043*' and show each query image
# next to its search hits.
with milvus_collection:
    (
  towhee.glob['path']('./extracted_test/n043*/*.JPEG')
        .image_decode['path', 'img']()
        .image_embedding.timm['img', 'vec'](model_name='resnet50')
        # ask for 5 hits per query vector
        .ann_search['vec', 'result'](ann_index=milvus_collection, limit=5)
        # look up each hit's path in `cache`, read it with OpenCV and wrap it
        # as a towhee Image tagged 'BGR'
        .runas_op['result', 'result_img'](func=lambda res: [Image(cv2.imread(cache[x.id]), 'BGR') for x in res])
        .select['img', 'result_img']()
        .show()
    )

## Play with gradio

In [None]:
from towhee.types.image_utils import from_pil
import gradio

# Build the search pipeline as a plain callable so gradio can invoke it.
with towhee.api() as api:
    search_in_milvus = (
        # gradio hands over a PIL image (type="pil" below); convert it first
        api.runas_op(func=lambda img: from_pil(img))
        .image_embedding.timm(model_name='resnet50')
        # here the index is addressed by URI rather than by the connector object
        .ann_search(ann_index='tcp://127.0.0.1:19530/reverse_image_search', limit=5)
        # map each hit id back to the image path stored in `cache`
        .runas_op(func=lambda res: [cache[x.id] for x in res])
        .as_function()
        )

# One upload input and five image outputs, matching limit=5 above.
interface = gradio.Interface(search_in_milvus, 
                             gradio.inputs.Image(type="pil", source='upload'),
                             [gradio.outputs.Image(type="file", label=None) for _ in range(5)]
                            )

interface.launch(inline=True)

## Advanced Test

In [7]:
# parallel execute
# Set up a separate 'reverse_image_search_bak' collection for the parallel ingest.
collection_res50_parallel = create_milvus_collection('reverse_image_search_bak', 2048)
# NOTE: this rebinds `milvus_collection`; cells executed after this one use
# the '_bak' collection through that name.
milvus_collection = towhee.connectors.milvus(uri='tcp://127.0.0.1:19530/reverse_image_search_bak')

In [10]:
# Same ingest as the earlier insert cell, plus set_parallel(3); the elapsed
# time is printed for comparison. Relies on `time` imported in that earlier cell.
time1 = time.time()

with milvus_collection:
    dc = (
        towhee.glob['path']('./extracted_train/*/*.JPEG')
          # presumably runs the following operators with 3 workers -- TODO confirm
          .set_parallel(3)
          .image_decode['path', 'img']()
          # label = name of the directory that contains the image
          .runas_op['path', 'label'](func = lambda path: path.split('/')[-2])
          .image_embedding.timm['img', 'vec'](model_name='resnet50')
          .batch(100)
          .ann_insert['vec', 'res'](ann_index=milvus_collection)
          .save_path_id[('path', 'res'), 'num']()
          # `dc` is iterated again by the metric-report cell
          .unstream()
    )
print(time.time()-time1)

70.05783486366272


In [None]:
# exception safe and drop empty
# Same search as above, run over ./exception/*.JPEG with exception_safe() at the
# head of the chain and drop_empty() before the output is selected.
# Relies on `cv2` and `Image` imported in the earlier search cell.
with milvus_collection:
    (
  towhee.glob['path']('./exception/*.JPEG')
        # presumably lets the chain continue past rows whose operator raised -- TODO confirm
        .exception_safe()
        .image_decode['path', 'img']()
        .image_embedding.timm['img', 'vec'](model_name='resnet50')
        .ann_search['vec', 'result'](ann_index=milvus_collection, limit=5)
        .runas_op['result', 'result_img'](func=lambda res: [Image(cv2.imread(cache[x.id]), 'BGR') for x in res])
        # presumably removes the rows left empty by a failed step -- TODO confirm
        .drop_empty()
        .select['img', 'result_img']()
        .show()
    )

In [None]:
# metric report
# collect label info
labels = {}
def collect(entity):
    """Append every path in ``entity`` to the ``labels`` list of its label.

    ``entity.label`` and ``entity.path`` are walked in lockstep; ``labels``
    maps each label to the list of paths seen for it, in arrival order.
    """
    for name, image_path in zip(entity.label, entity.path):
        labels.setdefault(name, []).append(image_path)
# Run `collect` over every entity in `dc`. `collect` returns None, so any()
# never short-circuits and simply drives the iteration.
any(map(collect, dc))

# Embedding dimension of each model under comparison.
model_dim = {
    'resnet50': 2048,
    'resnet101': 2048
}

for model in model_dim:
    # One collection per model, named after the model. The returned object is
    # not used again below; `milvus_collection` is rebound on every iteration.
    collection = create_milvus_collection(model, model_dim[model])
    milvus_collection = towhee.connectors.milvus(uri=f'tcp://127.0.0.1:19530/{model}')
    
    with milvus_collection:
        # Ingest: embed the training images with `model` and insert them.
        ( towhee.glob['path']('./extracted_train/*/*.JPEG')
                .set_parallel(3)
                .image_decode['path', 'img']()
                .image_embedding.timm['img', 'vec'](model_name=model)
                .batch(100)
                .ann_insert['vec', 'res'](ann_index=milvus_collection)
                .save_path_id[('path', 'res'), 'num']()
                .run()
        )

        # Evaluate: search with the test images and score the returned paths.
        ( towhee.glob['path']('./extracted_test/*/*.JPEG')
                .set_parallel(3)
                .image_decode['path', 'img']()
                .image_embedding.timm['img', 'vec'](model_name=model)
                # ground truth = all collected training paths that share the
                # query image's directory name
                .runas_op['path', 'ground_truth'](func=lambda path: labels[path.split('/')[-2]])
                .ann_search['vec', 'result'](ann_index=milvus_collection, limit=5)
                # overwrite 'result' with the hit paths looked up in `cache`
                .runas_op['result', 'result'](func=lambda res: [cache[x.id] for x in res])
                .with_metrics(['mean_hit_ratio', 'mean_average_precision'])
                .evaluate['ground_truth', 'result'](model)
                .report()
        )

In [None]:
# Define a Huggingface ViT Model
from transformers import ViTFeatureExtractor, ViTModel

# The checkpoint is loaded on first use and then shared by every call, instead
# of being reloaded with from_pretrained() for each image.
_VIT_CHECKPOINT = 'google/vit-base-patch16-224-in21k'
_vit_components = None


def _get_vit_components():
    """Return the (feature_extractor, model) pair, loading it on the first call."""
    global _vit_components
    if _vit_components is None:
        feature_extractor = ViTFeatureExtractor.from_pretrained(_VIT_CHECKPOINT)
        model = ViTModel.from_pretrained(_VIT_CHECKPOINT)
        # Publish both objects in a single assignment so a concurrent caller
        # never sees a half-initialised pair.
        _vit_components = (feature_extractor, model)
    return _vit_components


def vit_embedding_model(img):
    """Embed one decoded image with the Hugging Face ViT checkpoint.

    Args:
        img: Decoded image object providing ``cv2_to_rgb()`` (the 'img' column
            produced by ``image_decode`` in this notebook).

    Returns:
        The model's ``pooler_output`` as a flattened NumPy array.
    """
    feature_extractor, model = _get_vit_components()
    inputs = feature_extractor(img.cv2_to_rgb(), return_tensors="pt")
    outputs = model(**inputs)
    return outputs.pooler_output.detach().numpy().flatten()

# Reverse image search with the Hugging Face ViT model.
# ViT-base produces 768-dimensional pooled embeddings (see the Hugging Face ViT
# docs), while every collection bound to `milvus_collection` in this notebook
# is created with dim=2048. The ViT vectors therefore get their own collection
# instead of being inserted through `milvus_collection`.
VIT_DIM = 768
create_milvus_collection('huggingface_vit', VIT_DIM)
vit_milvus_collection = towhee.connectors.milvus(uri='tcp://127.0.0.1:19530/huggingface_vit')

# Wrapped in `with`, like the other ingest/evaluate cells of this notebook.
with vit_milvus_collection:
    # Ingest: embed the training images with ViT and insert them.
    (
        towhee.glob['path']('./extracted_train/*/*.JPEG')
            .exception_safe()
            .set_parallel(4)
            .image_decode['path', 'img']()
            .runas_op['img', 'vec'](func=vit_embedding_model)
            .batch(100)
            .ann_insert['vec', 'res'](ann_index=vit_milvus_collection)
            .save_path_id[('path', 'res'), 'num']()
            .drop_empty()
            .run()
    )

    # Evaluate: search with the test images and score the returned paths
    # against the training paths collected in `labels`.
    (
        towhee.glob['path']('./extracted_test/*/*.JPEG')
            .exception_safe()
            .set_parallel(4)
            .image_decode['path', 'img']()
            .runas_op['img', 'vec'](func=vit_embedding_model)
            .runas_op['path', 'ground_truth'](func=lambda path: labels[path.split('/')[-2]])
            .ann_search['vec', 'result'](ann_index=vit_milvus_collection, limit=5)
            .runas_op['result', 'result'](func=lambda res: [cache[x.id] for x in res])
            .with_metrics(['mean_hit_ratio', 'mean_average_precision'])
            .drop_empty()
            .evaluate['ground_truth', 'result']('huggingface_vit')
            .report()
    )

## Run with FastAPI

In [None]:
from fastapi import FastAPI
# FastAPI application that the endpoints below are attached to.
app = FastAPI()
# Point `milvus_collection` at the 'resnet101' collection (set up by the
# metric-report loop above).
milvus_collection = towhee.connectors.milvus(uri='tcp://127.0.0.1:19530/resnet101')
    
# POST /insert: load the uploaded file as an image, save it under tmp/images,
# embed it with resnet101, insert the vector and record key -> path in `cache`.
with towhee.api['file']() as api:
    with milvus_collection:
        app_insert = (
            api.image_load['file', 'img']()
              .save_image['img', 'path'](dir='tmp/images')
              .image_embedding.timm['img', 'vec'](model_name='resnet101')
              # NOTE(review): there is no .batch() here, so 'path' is presumably
              # a single value per request rather than a list -- confirm
              # save_path_id handles that.
              .ann_insert['vec', 'res'](ann_index=milvus_collection)
              .save_path_id[('path', 'res'), 'num']()
              .select['res']()
              .serve('/insert', app)
        )

In [None]:
# POST /search: embed the uploaded image with resnet101, query Milvus and
# respond with the hit paths looked up in `cache`.
with towhee.api['file']() as api:
    with milvus_collection:
        (
         api.image_load['file', 'img']()
            .image_embedding.timm['img', 'vec'](model_name='resnet101')
            # no `limit` is passed here, unlike the notebook searches above (limit=5)
            .ann_search['vec', 'result'](ann_index=milvus_collection)
            # the 'file' column is overwritten with the list of hit paths
            .runas_op['result', 'file'](func=lambda res: [cache[x.id] for x in res])
            .select['file']()
            .serve('/search', app)
        )

In [None]:
@towhee.register(name='milvus-count')
class MilvusCount:
    """Operator that reports the entity count of a Milvus collection."""

    def __init__(self, collection):
        # Accept either a collection object or a collection name to open.
        given_by_name = isinstance(collection, str)
        self.collection = Collection(collection) if given_by_name else collection

    def __call__(self, *args):
        """Return the collection's ``num_entities``; positional inputs are ignored."""
        entity_count = self.collection.num_entities
        return entity_count


# POST /count: respond with the entity count of the 'resnet101' collection,
# using the MilvusCount operator registered above.
with towhee.api() as api:
    app_count = (
        api.milvus_count(collection='resnet101')
        .serve('/count', app)
        )

In [None]:
import uvicorn
import nest_asyncio

# nest_asyncio is applied first, presumably so uvicorn can start its event loop
# inside the notebook's already-running loop.
nest_asyncio.apply()
# host='0.0.0.0' listens on all network interfaces; use '127.0.0.1' to keep the
# service local. This call blocks the cell until the server is stopped.
uvicorn.run(app=app, host='0.0.0.0', port=8000)

Then try these APIs with the following commands:
```shell
# upload an image and search
$ curl -X POST "http://0.0.0.0:8000/search"  --data-binary @extracted_test/n01443537/n01443537_3883.JPEG -H 'Content-Type: image/jpeg'
# upload an image and insert
$ curl -X POST "http://0.0.0.0:8000/insert"  --data-binary @extracted_test/n01443537/n01443537_3883.JPEG -H 'Content-Type: image/jpeg'
# count the collection
$ curl -X POST "http://0.0.0.0:8000/count"
```