<a href="https://colab.research.google.com/github/osaeed-ds/astrapy-imagesearch/blob/main/ImageSearchAstraPy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## **Introduction**

In this exercise, we adapt a notebook that connected to AstraDB via CassIO so that it uses AstraPy and the document API instead.

This notebook adapts the Multimodal Image Search demo by Mukundha:
https://github.com/mukundha/multi-modal-vector-retrieval-astra


## **Prerequisites Setup**

* Follow [these steps](https://docs.datastax.com/en/astra-serverless/docs/vector-search/overview.html) to create a new vector-search-enabled database in Astra.
* Enable the Preview Developer Experience.
* From the Database Details panel on the right side of your database screen:
  * generate and save an Application Token
  * save the API Endpoint
* The full Flickr 8k dataset can be downloaded from Kaggle (https://www.kaggle.com/datasets/adityajn105/flickr8k), but for speed this repo uses a subset of around 100 images and their corresponding captions. The subset is packaged as `flickr.tar.gz` in https://github.com/osaeed-ds/astrapy-imagesearch. Download that archive and either upload it to Colab or extract it locally.
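If you prefer to extract the archive from Python (for example, after uploading it to Colab), here is a minimal sketch using the standard-library `tarfile` module. The function name `extract_dataset` is hypothetical, and it assumes the archive unpacks to a top-level `flickr/` folder containing `captions.txt` and `Images/`, which is the layout the later cells read from.

```python
import os
import tarfile


def extract_dataset(archive_path: str, dest_dir: str = ".") -> str:
    """Extract the flickr subset archive and return the path to captions.txt.

    Assumption: the archive contains a top-level ``flickr/`` folder with
    ``captions.txt`` and an ``Images/`` directory.
    """
    with tarfile.open(archive_path, "r:gz") as tar:
        # Use the safer "data" extraction filter where this Python version supports it.
        if hasattr(tarfile, "data_filter"):
            tar.extractall(dest_dir, filter="data")
        else:
            tar.extractall(dest_dir)
    captions_path = os.path.join(dest_dir, "flickr", "captions.txt")
    if not os.path.exists(captions_path):
        raise FileNotFoundError(f"expected {captions_path} after extraction")
    return captions_path
```

In Colab this would be called as `extract_dataset('flickr.tar.gz')` from the directory that later cells treat as `os.getcwd()`.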


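```python
# Install OpenAI's CLIP from GitHub (not the unrelated `clip` package on PyPI).
# Assumption: the cells below use the legacy `astrapy.db` interface; if a newer
# AstraPy release no longer provides it, install an older AstraPy version.
!pip install pandas torch pillow langchain astrapy git+https://github.com/openai/CLIP.git
```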

```python
import os
import torch
import clip
from PIL import Image
import pandas as pd
from astrapy.db import AstraDB
from getpass import getpass
```

## **Setup AstraDB Connection and Create Collection**


```python
# Input your API endpoint, e.g. https://<DATABASE_ID>-<REGION>.apps.astra.datastax.com
ASTRA_DB_API_ENDPOINT = input('Your Astra DB API Endpoint (e.g. https://UUID-REGION.apps.astra.datastax.com): ')
```

```python
# Input your Astra DB Application Token string, the one starting with "AstraCS:..."
ASTRA_DB_TOKEN_BASED_PASSWORD = getpass('Your AstraDB Application Token (starts with AstraCS): ')
```

```python
# Input your collection name:
ASTRA_DB_COLLECTION = input('Your collection name (e.g. image_vectors): ')
```

```python
# Input your vector's dimensionality (512 for the CLIP ViT-B/32 model).
# input() returns a string, so convert it to an int for create_collection.
ASTRA_DB_VECTOR_DIMENSIONALITY = int(input('The dimensionality of your embedding vector (e.g. 512 for the CLIP ViT-B/32 embedding model): '))
```


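```python
# Connect to the database with the token and API endpoint entered above.
db = AstraDB(
    token=ASTRA_DB_TOKEN_BASED_PASSWORD,
    api_endpoint=ASTRA_DB_API_ENDPOINT,
)

# Create a vector-enabled collection; the dimension must match the embedding size.
col = db.create_collection(ASTRA_DB_COLLECTION, dimension=int(ASTRA_DB_VECTOR_DIMENSIONALITY))
```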

## **Load the CLIP Model**


We are going to use CLIP from OpenAI. Note that PyPI also hosts an unrelated Python package named `clip`; you will get errors if you run `pip install clip` instead of installing OpenAI's CLIP from GitHub.

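```python
# Use the GPU if one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
# clip.load returns the model and its image preprocessing transform.
model, transform = clip.load("ViT-B/32", device=device)
```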



## **Load the captions**
Load the captions for each image. The CLIP model takes an image and its caption as inputs to generate the embeddings, which enables vector search over the images either by caption text or by a similar image.


```python
df = pd.read_csv('flickr/captions.txt')
df
```

| Unnamed: 0 | image | caption |
|---|---|---|
| 0 | 101654506_8eb26cfb60.jpg | A brown and white dog is running through the s... |
| 1 | 101654506_8eb26cfb60.jpg | A dog is running in the snow |
| 2 | 101654506_8eb26cfb60.jpg | A dog running through snow . |
| 3 | 101654506_8eb26cfb60.jpg | a white and brown dog is running through a sno... |
| 4 | 101654506_8eb26cfb60.jpg | The white and brown dog is running over the su... |
| ... | ... | ... |
| 460 | 99679241_adc853a5c0.jpg | A grey bird stands majestically on a beach whi... |
| 461 | 99679241_adc853a5c0.jpg | A large bird stands in the water on the beach . |
| 462 | 99679241_adc853a5c0.jpg | A tall bird is standing on the sand beside the... |
| 463 | 99679241_adc853a5c0.jpg | A water bird standing at the ocean 's edge . |
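The table has several caption rows per image, and the loading step below embeds every row, so each image ends up stored once per caption. A minimal standard-library sketch of that grouping follows; the helper `captions_by_image` is hypothetical, and the sample rows are copied from the table above.

```python
import csv
import io
from collections import defaultdict


def captions_by_image(csv_text: str) -> dict[str, list[str]]:
    """Group caption rows by image file name, preserving row order."""
    groups: dict[str, list[str]] = defaultdict(list)
    for row in csv.DictReader(io.StringIO(csv_text)):
        groups[row["image"]].append(row["caption"])
    return dict(groups)


# Rows copied from the captions table (the first, unnamed column is the row index).
sample = """,image,caption
0,101654506_8eb26cfb60.jpg,A dog is running in the snow
1,101654506_8eb26cfb60.jpg,A dog running through snow .
2,99679241_adc853a5c0.jpg,A large bird stands in the water on the beach .
"""

grouped = captions_by_image(sample)
for image, caps in grouped.items():
    print(image, len(caps))
```

With the DataFrame already loaded, `df.groupby('image')['caption'].count()` gives the same per-image counts in pandas.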


## **Helper functions to generate embeddings**


```python
# Based on this paper
# https://ai.meta.com/research/publications/scaling-autoregressive-multi-modal-models-pretraining-and-instruction-tuning/
def get_clip_embedding(text, image_path):
    """Embed an (image, caption) pair as the mean of the normalized CLIP image and text embeddings."""
    image = transform(Image.open(image_path)).unsqueeze(0).to(device)
    tokens = clip.tokenize(text, truncate=True).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(tokens)
    # Scale each embedding to unit length before averaging.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    averaged_features = (image_features + text_features) / 2
    # Move to the CPU first: .numpy() fails on CUDA tensors.
    return averaged_features.cpu().numpy().tolist()

def embed_query(q):
    """Embed a text query with CLIP's text encoder."""
    tokens = clip.tokenize(q, truncate=True).to(device)  # use the argument, not a global
    with torch.no_grad():
        text_features = model.encode_text(tokens)
        text_features /= text_features.norm(dim=-1, keepdim=True)
    return text_features.cpu().numpy().tolist()[0]

def embed_image(image_path):
    """Embed a query image with CLIP's image encoder."""
    image = transform(Image.open(image_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
    return image_features.cpu().numpy().tolist()[0]
```
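`get_clip_embedding` stores the mean of the unit-normalized image and text embeddings. The toy sketch below uses made-up 3-dimensional vectors in place of CLIP's 512-dimensional ones to show why one stored vector can serve both query types: the averaged vector is equally close, by cosine similarity, to the image embedding and to the text embedding it came from. It assumes the collection compares vectors with cosine similarity (our understanding of Astra's default metric, stated here as an assumption). The averaged vector is not itself unit length; cosine similarity ignores that, but a dot-product metric would not.

```python
import math


def l2_normalize(v: list[float]) -> list[float]:
    """Scale a vector to unit length, like .norm(dim=-1) division in the helpers."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def average(a: list[float], b: list[float]) -> list[float]:
    """Element-wise mean, mirroring (image_features + text_features) / 2."""
    return [(x + y) / 2 for x, y in zip(a, b)]


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


# Toy stand-ins for an image embedding and a caption embedding (made-up numbers).
image_vec = l2_normalize([1.0, 0.0, 0.0])
text_vec = l2_normalize([0.0, 1.0, 0.0])
combined = average(image_vec, text_vec)

print(round(cosine(combined, image_vec), 4))  # 0.7071
print(round(cosine(combined, text_vec), 4))   # 0.7071
```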


## **Generate Vectors and Load the data into AstraDB**
Load all of the Flickr subset.
This step is inefficient: for each row it deletes any existing document with the same `_id` (the DataFrame row index) and then inserts a new document, one request at a time. The API also supports inserting multiple documents in one call, as well as updating existing documents.

```python
# For each caption row: delete any previous document with this _id, embed the
# (image, caption) pair, and insert a new document carrying the vector.
for index, row in df.iterrows():
    col.delete(id=index)
    image_url = f'{os.getcwd()}/flickr/Images/{row["image"]}'
    caption = row['caption']
    vector = get_clip_embedding(caption, image_url)[0]
    col.insert_one(
        {
            "_id": index,
            "image_url": image_url,
            "caption": caption,
            "$vector": vector,
        }
    )
    print(index, image_url, caption, vector)
```
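The loading loop makes two requests per row. A minimal sketch of batched loading follows, assuming the legacy `astrapy.db` collection's `insert_many(list_of_documents)` method and a per-request cap of 20 documents; both are assumptions, so check the limits for your AstraPy and Data API versions. The helpers `chunked` and `insert_in_batches` are hypothetical.

```python
from typing import Iterator

# Assumption: maximum number of documents accepted by one insertMany request.
BATCH_SIZE = 20


def chunked(docs: list[dict], size: int) -> Iterator[list[dict]]:
    """Yield consecutive slices of at most `size` documents."""
    for start in range(0, len(docs), size):
        yield docs[start:start + size]


def insert_in_batches(collection, docs: list[dict], size: int = BATCH_SIZE) -> int:
    """Send one insert_many call per batch and return how many batches were sent."""
    batches = 0
    for batch in chunked(docs, size):
        collection.insert_many(batch)
        batches += 1
    return batches
```

In the notebook, `docs` would be the same dictionaries the loop builds, and the call would be `insert_in_batches(col, docs)`. As far as we know, inserting a document whose `_id` already exists is rejected rather than overwritten, so this sketch assumes an empty collection, such as one freshly created by `create_collection`.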

## **Text Query Search**


```python
# Text query search: embed the query text and return the 3 nearest documents.
query = "boy running outside"
results = col.vector_find(embed_query(query), limit=3)
for r in results:
    print(r)
```



```
{'_id': 49, 'image_url': '/content/flickr/Images/106490881_5a2dd9b7bd.jpg', 'caption': 'The boy is playing on the shore of an ocean .', '$vector': [0.005147820804268122, 0.02183985337615013, -0.017937414348125458, 0.0008977535180747509, ...
```
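Conceptually, `vector_find` returns the stored documents whose `$vector` is most similar to the query vector. The brute-force sketch below reproduces that ranking locally with cosine similarity, assuming the collection's metric is cosine. It is only a mental model: the database answers the query from a vector index rather than by scanning every document, so results on a large collection can differ slightly. The helper `brute_force_top_k` and the toy vectors are hypothetical.

```python
import heapq
import math


def cosine(a: list[float], b: list[float]) -> float:
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def brute_force_top_k(query_vec: list[float], docs: list[dict], k: int = 3) -> list[dict]:
    """Return the k documents whose "$vector" is most cosine-similar to query_vec."""
    return heapq.nlargest(k, docs, key=lambda d: cosine(query_vec, d["$vector"]))


# Made-up 2-d documents standing in for the stored CLIP vectors.
toy_docs = [
    {"_id": 0, "$vector": [1.0, 0.0]},
    {"_id": 1, "$vector": [0.6, 0.8]},
    {"_id": 2, "$vector": [0.0, 1.0]},
    {"_id": 3, "$vector": [-1.0, 0.0]},
]
print([d["_id"] for d in brute_force_top_k([1.0, 0.2], toy_docs, k=3)])  # [0, 1, 2]
```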

## **Image Query Search**


```python
# Image query search: embed a query image and return the 3 nearest documents.
inp_img = f'{os.getcwd()}/flickr/Images/55135290_9bed5c4ca3.jpg'
print(inp_img)
results = col.vector_find(embed_image(inp_img), limit=3)
for r in results:
    print(r)
```

```
/content/flickr/Images/55135290_9bed5c4ca3.jpg
{'_id': 240, 'image_url': '/content/flickr/Images/55135290_9bed5c4ca3.jpg', 'caption': 'A boy wearing a red shirt and jeans is doing a flip on his bike .', '$vector': [0.0036213360726833344, 0.062300097197294235, 0.0039147078059613705, 0.003227377776056528, ...
```
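Because each image is stored once per caption, the top hits for a query can include several captions of the same image. One option, sketched here, is to request a larger `limit` from `vector_find` and keep only the first hit per `image_url`. The helper `unique_images` and the over-fetch factor in the usage note are assumptions, not part of the original notebook.

```python
def unique_images(results: list[dict], k: int) -> list[dict]:
    """Keep the first (highest-ranked) hit per image_url, up to k images, in rank order."""
    seen: set[str] = set()
    out: list[dict] = []
    for doc in results:
        url = doc["image_url"]
        if url not in seen:
            seen.add(url)
            out.append(doc)
            if len(out) == k:
                break
    return out
```

For example: `unique_images(col.vector_find(embed_image(inp_img), limit=15), k=3)` returns up to three distinct images instead of possibly three captions of one image.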

## **Delete Collection**


```python
# Delete the collection and all of its documents.
db.delete_collection(collection_name=ASTRA_DB_COLLECTION)
```