# Multimodal search with CLIP

In [None]:
import clip
import pymongo
from pinnacledb.misc.pinnacle import pinnacle
from pinnacledb.models.torch.wrapper import TorchModel
from pinnacledb.datalayer.mongodb.query import Collection
from pinnacledb.core.documents import Document as D
from pinnacledb.encoders.pillow.image import pil_image as i

# Reuse one MongoClient for all setup work instead of opening a fresh client
# (and connection pool) per statement.
mongo_client = pymongo.MongoClient()

# Drop both databases so re-running the notebook starts from an empty state.
# NOTE(review): '_filesystem:documents' is presumably the artifact store that
# pinnacledb keeps next to 'documents' -- confirm against the pinnacledb docs.
for stale_database in ('documents', '_filesystem:documents'):
    mongo_client.drop_database(stale_database)

# Wrap the raw pymongo database; `db` is the handle used by every later cell.
db = pinnacle(mongo_client.documents)

# Query builder for the collection that will hold the sampled images.
collection = Collection(name='tiny-imagenet')

In [None]:
from pinnacledb.core.documents import Document as D
from pinnacledb.encoders.pillow.image import pil_image as i
from datasets import load_dataset
import random

# Only the 'valid' split of the dataset is used in this notebook.
validation_split = load_dataset("zh-plus/tiny-imagenet")['valid']

# Choose the 1000 row positions first so that only the sampled rows are
# read and wrapped as Documents, rather than wrapping the entire split and
# discarding most of it afterwards. Sampling from range(len(...)) picks the
# same positions as sampling from a list of that length.
sampled_rows = random.sample(range(len(validation_split)), 1000)

# Each Document carries a single 'image' field, wrapped by the pil_image
# encoder `i`.
dataset = [
    D({'image': i(validation_split[row]['image'])})
    for row in sampled_rows
]

In [None]:
# Build the insert query for the sampled documents, handing over the
# pil_image encoder `i` alongside them, then run it against the datalayer.
# NOTE(review): `encoders=(i,)` presumably registers the encoder needed to
# store the 'image' field -- confirm.
insert_query = collection.insert_many(dataset, encoders=(i,))
db.execute(insert_query)

In [None]:
# Fetch one stored document and keep the payload of its 'image' field in `x`;
# it is used further down as a sample input for the visual model.
sample_document = db.execute(collection.find_one())
x = sample_document['image'].x

In [None]:
# Load the "ViT-B/32" CLIP checkpoint on the CPU. `clip.load` hands back the
# model together with its matching image preprocessing callable.
clip_bundle = clip.load("ViT-B/32", device='cpu')
model, preprocess = clip_bundle

In [None]:
from pinnacledb.encoders.torch.tensor import tensor
import torch

# Encoder describing the model outputs: float tensors of shape (512,).
# NOTE(review): 512 is presumably the CLIP ViT-B/32 embedding width -- confirm
# if a different checkpoint is loaded.
embedding_shape = (512,)
t = tensor(torch.float, shape=embedding_shape)

In [None]:
# Text side of CLIP: wraps the full CLIP model and routes calls through its
# `encode_text` method. Outputs are described by the tensor encoder `t`.
text_model = TorchModel(
    object=model,
    identifier='clip_text',
    forward_method='encode_text',
    # `clip.tokenize` returns one row per input text; `[0]` keeps the row for
    # the single string passed in.
    preprocess=lambda x: clip.tokenize(x)[0],
    encoder=t,
)

In [None]:
# Smoke-test the text model on a single string (`one=True` marks the input as
# one datapoint rather than a batch); the cell displays the returned value.
sample_text = 'this is a test'
text_model.predict(sample_text, one=True)

In [None]:
# Image side of CLIP: wraps only the `visual` sub-module of the loaded model
# and feeds it inputs prepared by CLIP's own `preprocess` callable. Outputs
# share the tensor encoder `t` with the text model.
visual_model = TorchModel(
    object=model.visual,
    identifier='clip_image',
    preprocess=preprocess,
    encoder=t,
)

In [None]:
# Smoke-test the visual model on the sample image `x` fetched earlier
# (`one=True` marks the input as one datapoint rather than a batch); the cell
# displays the returned value.
sample_image = x
visual_model.predict(sample_image, one=True)

In [None]:
from pinnacledb.core.vector_index import VectorIndex
from pinnacledb.core.watcher import Watcher

# Watcher that applies the visual model to the 'image' key of every document
# selected by `collection.find()`.
image_watcher = Watcher(
    model=visual_model,
    key='image',
    select=collection.find(),
)

# Watcher pairing the text model with the 'text' key. It is created with
# active=False and without a `select`.
# NOTE(review): presumably this means it is only used to encode queries, not
# to process stored documents -- confirm against the Watcher docs.
text_watcher = Watcher(
    model=text_model,
    key='text',
    active=False,
)

# Register the index 'my-index' built from the two watchers above.
multimodal_index = VectorIndex(
    'my-index',
    indexing_watcher=image_watcher,
    compatible_watcher=text_watcher,
)
db.add(multimodal_index)

In [None]:
# Text-to-image search: ask 'my-index' for the 3 entries most similar to the
# text 'mushroom', then chain an unfiltered `find` on top of that.
similarity_query = collection.like(
    D({'text': 'mushroom'}),
    vector_index='my-index',
    n=3,
).find({})
out = db.execute(similarity_query)

In [None]:
from IPython.display import display
# Render the image payload of each search result inline in the notebook.
for result in out:
    matched_image = result['image'].x
    display(matched_image)