# Multimodal (Image + Text) Retrieval for Ecommerce

The purpose of this demo is to show how to run multimodal search based on vector similarity.

By multimodal, we mean the ability to encode texts and images in the same vector space.

Dataset available at: https://www.kaggle.com/datasets/PromptCloudHQ/flipkart-products

References: https://akgeni.medium.com/understanding-openai-clip-its-applications-452bd214e226

## Dependencies 

This demo is built on OpenAI/CLIP, which must be installed from its GitHub repository.

In [None]:
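# cassio, python-dotenv, requests, tqdm, scikit-learn and matplotlib are imported by later cells
!pip install pandas astrapy cassio python-dotenv requests tqdm scikit-learn matplotlib --quiet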

In [None]:
!pip install git+https://github.com/openai/CLIP.git

In [1]:
import cassio
import torch
import clip
from PIL import Image
import pandas as pd 
from IPython.display import Image as ShowImage, display
import numpy as np

In [2]:
from dotenv import load_dotenv, find_dotenv
import os
load_dotenv(find_dotenv(), override=True)

True

## AstraDB Connection

We will use AstraDB, a DBaaS based on Cassandra with built-in Vector Search, to store and query the embeddings.

To open the connection and create the table that will store the data, we will use the **CassIO** library [www.cassio.org].


In [None]:
#token = os.environ['ASTRA_DB_APPLICATION_TOKEN']
#database_id = os.environ['ASTRA_DB_ID']
#keyspace = os.environ.get('ASTRA_DB_KEYSPACE')
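
The commented-out cell above reads the connection settings from environment variables. As a minimal sketch of that step, the helper below collects the same three variables and fails early when a required one is missing. The function name `read_astra_settings` and the placeholder values are hypothetical; only the variable names come from the cell above.

```python
import os

# Variable names taken from the commented-out cell above.
REQUIRED = ("ASTRA_DB_APPLICATION_TOKEN", "ASTRA_DB_ID")
OPTIONAL = ("ASTRA_DB_KEYSPACE",)

def read_astra_settings(env=os.environ):
    """Hypothetical helper: collect the Astra settings, failing early if one is missing."""
    missing = [name for name in REQUIRED if not env.get(name)]
    if missing:
        raise KeyError(f"Missing environment variables: {', '.join(missing)}")
    settings = {name: env[name] for name in REQUIRED}
    for name in OPTIONAL:
        settings[name] = env.get(name)  # None when the variable is not set
    return settings

# A plain dict stands in for os.environ here; the values are placeholders.
example_env = {
    "ASTRA_DB_APPLICATION_TOKEN": "AstraCS:placeholder",
    "ASTRA_DB_ID": "placeholder-id",
}
settings = read_astra_settings(example_env)
print(sorted(settings))
```

Failing before any connection attempt makes a missing `.env` entry easier to spot than a later authentication error.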

In [3]:
from cqlsession import getCQLSession, getCQLKeyspace
cqlMode = 'dse' 
session = getCQLSession()
keyspace = getCQLKeyspace()

In [4]:
# Open the connection
#cassio.init(token=token, database_id=database_id)
#Create the table
v_store = cassio.table.MetadataVectorCassandraTable(session=session,
                                                    keyspace=keyspace,
                                                    table="demo_ecommerce", vector_dimension=512)

## The Model

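In this case, we will use the OpenAI/CLIP model in its `ViT-B/32` variant. Its image and text encoders both output 512-dimensional embeddings, which matches the `vector_dimension=512` of the table created above.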

![image.png](attachment:b4471e54-14ee-4731-9e99-b51eb46b3d27.png)

In [5]:
# Checking if GPU is available. Then loading the model.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, transform = clip.load("ViT-B/32", device=device)


## Dataset

The dataset used in this demo is available on Kaggle. It contains 20,000 products with their names, descriptions, categories and image URLs.

The image files themselves are not included in the Kaggle dataset, so we have to download them. This step takes a long time because the downloads are slow.

In [6]:
df = pd.read_csv('/Users/samuel.matioli/work/gen-ai/multimodal/flipkart/flipkart_com-ecommerce_sample.csv')
df.head()

Unnamed: 0,uniq_id,crawl_timestamp,product_url,product_name,product_category_tree,pid,retail_price,discounted_price,image,is_FK_Advantage_product,description,product_rating,overall_rating,brand,product_specifications
0,c2d766ca982eca8304150849735ffef9,2016-03-25 22:59:23 +0000,http://www.flipkart.com/alisha-solid-women-s-c...,Alisha Solid Women's Cycling Shorts,"[""Clothing >> Women's Clothing >> Lingerie, Sl...",SRTEH2FF9KEDEFGF,999.0,379.0,"[""http://img5a.flixcart.com/image/short/u/4/a/...",False,Key Features of Alisha Solid Women's Cycling S...,No rating available,No rating available,Alisha,"{""product_specification""=>[{""key""=>""Number of ..."
1,7f7036a6d550aaa89d34c77bd39a5e48,2016-03-25 22:59:23 +0000,http://www.flipkart.com/fabhomedecor-fabric-do...,FabHomeDecor Fabric Double Sofa Bed,"[""Furniture >> Living Room Furniture >> Sofa B...",SBEEH3QGU7MFYJFY,32157.0,22646.0,"[""http://img6a.flixcart.com/image/sofa-bed/j/f...",False,FabHomeDecor Fabric Double Sofa Bed (Finish Co...,No rating available,No rating available,FabHomeDecor,"{""product_specification""=>[{""key""=>""Installati..."
2,f449ec65dcbc041b6ae5e6a32717d01b,2016-03-25 22:59:23 +0000,http://www.flipkart.com/aw-bellies/p/itmeh4grg...,AW Bellies,"[""Footwear >> Women's Footwear >> Ballerinas >...",SHOEH4GRSUBJGZXE,999.0,499.0,"[""http://img5a.flixcart.com/image/shoe/7/z/z/r...",False,Key Features of AW Bellies Sandals Wedges Heel...,No rating available,No rating available,AW,"{""product_specification""=>[{""key""=>""Ideal For""..."
3,0973b37acd0c664e3de26e97e5571454,2016-03-25 22:59:23 +0000,http://www.flipkart.com/alisha-solid-women-s-c...,Alisha Solid Women's Cycling Shorts,"[""Clothing >> Women's Clothing >> Lingerie, Sl...",SRTEH2F6HUZMQ6SJ,699.0,267.0,"[""http://img5a.flixcart.com/image/short/6/2/h/...",False,Key Features of Alisha Solid Women's Cycling S...,No rating available,No rating available,Alisha,"{""product_specification""=>[{""key""=>""Number of ..."
4,bc940ea42ee6bef5ac7cea3fb5cfbee7,2016-03-25 22:59:23 +0000,http://www.flipkart.com/sicons-all-purpose-arn...,Sicons All Purpose Arnica Dog Shampoo,"[""Pet Supplies >> Grooming >> Skin & Coat Care...",PSOEH3ZYDMSYARJ5,220.0,210.0,"[""http://img5a.flixcart.com/image/pet-shampoo/...",False,Specifications of Sicons All Purpose Arnica Do...,No rating available,No rating available,Sicons,"{""product_specification""=>[{""key""=>""Pet Type"",..."


## Download Images

The original dataset doesn't include the image files, so they have to be downloaded from the URLs in the `image` column.

In [7]:
import os
import requests
from tqdm import tqdm

def download_image(url, folder, prefix_file_name):
    """Download `url` into `folder`; return the local path, or None on a non-200 response."""
    # Create the target folder on first use.
    os.makedirs(folder, exist_ok=True)
    file_name = f"{prefix_file_name}_{os.path.basename(url)}"
    file_path = os.path.join(folder, file_name)
    # Reuse a file that was already downloaded in a previous run.
    if os.path.exists(file_path):
        return file_path
    # The timeout keeps one stalled request from blocking the whole loop.
    response = requests.get(url, timeout=30)
    if response.status_code != 200:
        return None
    with open(file_path, 'wb') as f:
        f.write(response.content)
    return file_path
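
To make the naming rule of `download_image` concrete, here is a small sketch that builds the local path without touching the network. The helper name `local_image_path`, the URL and the folder are made up for illustration.

```python
import os

def local_image_path(url, folder, prefix_file_name):
    """Mirror the naming rule of download_image: <prefix>_<last URL segment>."""
    file_name = f"{prefix_file_name}_{os.path.basename(url)}"
    return os.path.join(folder, file_name)

# Hypothetical URL, folder and prefix, used only to show the resulting name.
path = local_image_path("http://example.com/image/shoe/abc.jpeg", "img", "f449ec65")
print(os.path.basename(path))
```

Prefixing the file name with the product's `uniq_id` keeps two products that share an image file name from overwriting each other. Note that `os.path.basename` keeps any query string that follows the file name in the URL.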

In [8]:
# Transform the string representation of a list into a list of strings
import ast

def parse_string_to_array(s):
    try:
        return ast.literal_eval(s) if pd.notna(s) else []
    except (SyntaxError, ValueError):
        return []
        
df['image'] = df['image'].apply(parse_string_to_array)
#df['product_category_tree'] = df['product_category_tree'].apply(ast.literal_eval)
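
As an illustration of what this parsing step does, the sketch below is a standard-library-only variant of `parse_string_to_array`. It is stricter than the function above because it also returns `[]` when the literal is not a list. The name `parse_list_literal` and the sample URLs are assumptions made for the example.

```python
import ast

def parse_list_literal(s):
    """Return the list encoded in `s`, or [] when `s` is missing or malformed."""
    if not isinstance(s, str):
        return []  # covers None and NaN (NaN is a float)
    try:
        value = ast.literal_eval(s)
    except (SyntaxError, ValueError):
        return []
    return value if isinstance(value, list) else []

# Hypothetical cell values in the shape of the `image` column.
raw = '["http://example.com/a.jpeg", "http://example.com/b.jpeg"]'
print(parse_list_literal(raw))
print(parse_list_literal(float("nan")))
print(parse_list_literal("[truncated"))
```

`ast.literal_eval` only evaluates Python literals, so it is a safer choice than `eval` for data coming from a CSV file.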

In [9]:
def extract_category(full_category):
    parse =  parse_string_to_array(full_category)
    if len(parse) > 0:
        return [cat.strip() for cat in parse[0].split(">>")][0]
    else:
        return None
        
df['category'] = df["product_category_tree"].apply(extract_category)
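
The category extraction can be tried on a single value. This is a minimal sketch with a made-up category tree; `top_level_category` is a hypothetical stand-alone equivalent of `extract_category` that does not depend on pandas.

```python
import ast

def top_level_category(tree_literal):
    """Return the first segment of the first category path, or None if there is none."""
    paths = ast.literal_eval(tree_literal)
    if not paths:
        return None
    # "A >> B >> C" -> "A"
    return paths[0].split(">>")[0].strip()

# Hypothetical value in the shape of the `product_category_tree` column.
sample = '["Footwear >> Sandals >> Example Brand Sandals"]'
print(top_level_category(sample))
print(top_level_category("[]"))
```

Keeping only the top-level segment gives a small set of labels, which is what the zero-shot classification at the end of the notebook relies on.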

In [10]:
df.head()

Unnamed: 0,uniq_id,crawl_timestamp,product_url,product_name,product_category_tree,pid,retail_price,discounted_price,image,is_FK_Advantage_product,description,product_rating,overall_rating,brand,product_specifications,category
0,c2d766ca982eca8304150849735ffef9,2016-03-25 22:59:23 +0000,http://www.flipkart.com/alisha-solid-women-s-c...,Alisha Solid Women's Cycling Shorts,"[""Clothing >> Women's Clothing >> Lingerie, Sl...",SRTEH2FF9KEDEFGF,999.0,379.0,[http://img5a.flixcart.com/image/short/u/4/a/a...,False,Key Features of Alisha Solid Women's Cycling S...,No rating available,No rating available,Alisha,"{""product_specification""=>[{""key""=>""Number of ...",Clothing
1,7f7036a6d550aaa89d34c77bd39a5e48,2016-03-25 22:59:23 +0000,http://www.flipkart.com/fabhomedecor-fabric-do...,FabHomeDecor Fabric Double Sofa Bed,"[""Furniture >> Living Room Furniture >> Sofa B...",SBEEH3QGU7MFYJFY,32157.0,22646.0,[http://img6a.flixcart.com/image/sofa-bed/j/f/...,False,FabHomeDecor Fabric Double Sofa Bed (Finish Co...,No rating available,No rating available,FabHomeDecor,"{""product_specification""=>[{""key""=>""Installati...",Furniture
2,f449ec65dcbc041b6ae5e6a32717d01b,2016-03-25 22:59:23 +0000,http://www.flipkart.com/aw-bellies/p/itmeh4grg...,AW Bellies,"[""Footwear >> Women's Footwear >> Ballerinas >...",SHOEH4GRSUBJGZXE,999.0,499.0,[http://img5a.flixcart.com/image/shoe/7/z/z/re...,False,Key Features of AW Bellies Sandals Wedges Heel...,No rating available,No rating available,AW,"{""product_specification""=>[{""key""=>""Ideal For""...",Footwear
3,0973b37acd0c664e3de26e97e5571454,2016-03-25 22:59:23 +0000,http://www.flipkart.com/alisha-solid-women-s-c...,Alisha Solid Women's Cycling Shorts,"[""Clothing >> Women's Clothing >> Lingerie, Sl...",SRTEH2F6HUZMQ6SJ,699.0,267.0,[http://img5a.flixcart.com/image/short/6/2/h/a...,False,Key Features of Alisha Solid Women's Cycling S...,No rating available,No rating available,Alisha,"{""product_specification""=>[{""key""=>""Number of ...",Clothing
4,bc940ea42ee6bef5ac7cea3fb5cfbee7,2016-03-25 22:59:23 +0000,http://www.flipkart.com/sicons-all-purpose-arn...,Sicons All Purpose Arnica Dog Shampoo,"[""Pet Supplies >> Grooming >> Skin & Coat Care...",PSOEH3ZYDMSYARJ5,220.0,210.0,[http://img5a.flixcart.com/image/pet-shampoo/r...,False,Specifications of Sicons All Purpose Arnica Do...,No rating available,No rating available,Sicons,"{""product_specification""=>[{""key""=>""Pet Type"",...",Pet Supplies


In [29]:
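%time
# Note: a bare `%time` line times only an empty statement, hence the microsecond timings in the output below.
# Put `%%time` on the first line of the cell to time the whole loop.
# Download the images, then embed and store each product.
# get_clip_embedding is defined in the "Generating Embeddings" section below and must be run before this cell.
skip = 0
load = 20000
errors = []
for index, row in tqdm(df[skip:load].iterrows(), total=df[skip:load].shape[0], desc='Reading DF'):
    try:
        if len(row["image"]) > 0:
            filepath = download_image(
                row["image"][0],
                "/Users/samuel.matioli/work/customers/magalu/ecommerce-gen-ai/public/img/",
                row['uniq_id'])
            if filepath is None:
                # Download failed: record the row and move on.
                errors.append(index)
                continue
            df.at[index, "local_image"] = filepath
            input_text = f"{row['category']} {row['product_name']}"
            v_store.put(row_id=f"row_{index}",
                        body_blob=input_text,
                        vector=get_clip_embedding(input_text, filepath)[0],
                        metadata={
                            "filename": os.path.basename(filepath),
                            "local_image": filepath})
    except Exception as err:
        # Any failure (download, image decoding, embedding, write) marks the row as failed.
        errors.append(index)
errors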

CPU times: user 1 µs, sys: 1 µs, total: 2 µs
Wall time: 2.86 µs


Reading DF: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████| 20000/20000 [8:05:14<00:00,  1.46s/it]


[31,
 90,
 93,
 94,
 95,
 97,
 99,
 100,
 111,
 116,
 119,
 123,
 126,
 130,
 138,
 143,
 144,
 148,
 153,
 162,
 164,
 168,
 170,
 172,
 177,
 183,
 185,
 188,
 198,
 201,
 203,
 204,
 215,
 217,
 219,
 221,
 224,
 233,
 246,
 248,
 252,
 254,
 258,
 259,
 261,
 263,
 264,
 275,
 278,
 280,
 281,
 285,
 287,
 291,
 308,
 319,
 352,
 371,
 372,
 377,
 378,
 436,
 469,
 472,
 473,
 474,
 475,
 499,
 501,
 502,
 503,
 504,
 505,
 549,
 550,
 554,
 579,
 581,
 583,
 585,
 587,
 588,
 608,
 611,
 612,
 633,
 645,
 646,
 647,
 648,
 649,
 650,
 651,
 652,
 653,
 733,
 761,
 762,
 763,
 789,
 790,
 791,
 792,
 793,
 794,
 845,
 846,
 847,
 848,
 868,
 869,
 870,
 922,
 1069,
 1070,
 1071,
 1072,
 1073,
 1074,
 1142,
 1146,
 1147,
 1153,
 1271,
 1326,
 1335,
 1339,
 1345,
 1362,
 1826,
 1976,
 1980,
 1981,
 1982,
 1987,
 1994,
 1996,
 1998,
 2003,
 2007,
 2008,
 2010,
 2011,
 2012,
 2014,
 2016,
 2017,
 2020,
 2022,
 2024,
 2033,
 2037,
 2051,
 2052,
 2058,
 2062,
 2064,
 2069,
 2076,
 2080,


## Generating Embeddings

Each product embedding is a "combination" of its image and text embeddings: both are L2-normalized and then averaged.

In [25]:
# Based on this paper 
# https://ai.meta.com/research/publications/scaling-autoregressive-multi-modal-models-pretraining-and-instruction-tuning/
def get_clip_embedding(text, image_path):
    # Preprocess the image and tokenize the text (truncated to CLIP's context length).
    image = transform(Image.open(image_path)).unsqueeze(0).to(device)
    text = clip.tokenize(text, truncate=True).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)
    # L2-normalize each modality so both contribute equally to the average.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    averaged_features = (image_features + text_features) / 2
    # Move to the CPU before converting: .numpy() fails on CUDA tensors.
    return averaged_features.cpu().numpy().tolist()
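
The averaging step can be illustrated without CLIP. The sketch below uses toy 2-dimensional vectors and plain Python; the helper names and the numbers are made up for the example. It shows that the average of two unit vectors is equally similar to both inputs, and that its own length is below 1.

```python
import math

def l2_normalize(v):
    """Scale `v` to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def average_embeddings(a, b):
    """Normalize both vectors, then take their element-wise mean."""
    a, b = l2_normalize(a), l2_normalize(b)
    return [(x + y) / 2 for x, y in zip(a, b)]

def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

# Toy stand-ins for an image embedding and a text embedding.
image_vec = [3.0, 4.0]
text_vec = [0.0, 2.0]
combined = average_embeddings(image_vec, text_vec)

print(combined)
print(cosine(combined, image_vec), cosine(combined, text_vec))
```

Because the combined vector is not unit length, a cosine-based comparison is the natural fit for it; a raw dot product would be scaled by its norm.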

### Loading to DSE

Generate the embeddings for the combination of text + image.

The text used for the embedding consists only of the first (top-level) category and the product name.

In [None]:
%time
# Note: a bare `%time` line times only an empty statement; put `%%time` on the first line of the cell to time the whole loop.
# Loading the Flipkart data (rows `skip` to `load`) to the Vector Table
skip = 1000
load = 20000
for index, row in tqdm(df[skip:load].iterrows(), total=df[skip:load].shape[0], desc='Loading DF to DSE'):
    try:
        # Rows without a downloaded image hold None or NaN in "local_image".
        if "local_image" in row.index and pd.notna(row["local_image"]):
            input_img = row["local_image"]
            input_text = f"{row['category']} {row['product_name']} "
            v_store.put(row_id=f"row_{index}", body_blob=input_text,
                        vector=get_clip_embedding(input_text, input_img)[0],
                        metadata={"filename": os.path.basename(input_img), "local_image": input_img})
    except Exception as error:
        print(f"Error at IX {index} {error}")


print("Finished")

In [None]:
df.iloc[0]["local_image"]

In [None]:
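# Function to convert the query into an embedding
def embed_query(q):
    # Tokenize the argument `q` (the original read the global `query` by mistake).
    query_embed = clip.tokenize(q, truncate=True).to(device)
    with torch.no_grad():
        text_features = model.encode_text(query_embed)
        text_features /= text_features.norm(dim=-1, keepdim=True)
    # Move to the CPU before converting: .numpy() fails on CUDA tensors.
    return text_features.cpu().numpy().tolist()[0]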

In [None]:
# Function to convert the image into an embedding
def embed_image(image_path):
    image = transform(Image.open(image_path)).unsqueeze(0).to(device)
    with torch.no_grad():
        image_features = model.encode_image(image)
        image_features /= image_features.norm(dim=-1, keepdim=True)
    # Move to the CPU before converting: .numpy() fails on CUDA tensors.
    return image_features.cpu().numpy().tolist()[0]

### Validating the similarity between text and image

In [None]:
query = "aluminum analog watch for men"
text_embedding = embed_query(query)
print(f"Text embedding {text_embedding[:12]}...")

In [None]:
file_to_search = '998276640805c5268d5b5f7534ef05e2_bh23-luba-original-imae9mn7hpk7hwrt.jpeg'
image_search = f"/Users/samuel.matioli/work/gen-ai/multimodal/flipkart/images/{file_to_search}"
display(ShowImage(filename=image_search, width=200) )
image_embedding = embed_image(image_search)
print(f"Image embedding {image_embedding[:12]}...")

In [None]:
# Checking the similarity
from sklearn.metrics.pairwise import cosine_similarity
cosine_similarity_score = cosine_similarity(np.array([text_embedding]), np.array([image_embedding]))
print("Cosine similarity:", cosine_similarity_score[0][0])

## Searching using a description

In [None]:
query = "kids furniture"
results = v_store.ann_search(n=6, vector=embed_query(query))

print("="*40)
print(f"Search: {query}")
print("="*40)
print("Results:")
print("-"*40)
for r in results:
    # The rows were stored with the metadata keys "filename" and "local_image".
    display(ShowImage(filename=r['metadata']['local_image'], width=200) )
    print(r['body_blob'])
    print(r['metadata']['local_image'])
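
Conceptually, the search above returns the stored vectors closest to the query vector. The sketch below is an exhaustive (exact) version of that idea on toy data, meant only as a reference for what an approximate nearest-neighbor search aims to return. It is not how the database indexes vectors, and the cosine metric, the function names and the sample rows are assumptions for the example.

```python
import math

def cosine(a, b):
    """Cosine similarity between two vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)

def exact_search(query_vec, rows, n=2):
    """Score every stored row against the query and keep the n best."""
    scored = [(row_id, cosine(query_vec, vec)) for row_id, vec in rows]
    return sorted(scored, key=lambda item: item[1], reverse=True)[:n]

# Toy table: (row_id, vector) pairs standing in for the stored product embeddings.
rows = [
    ("row_0", [1.0, 0.0]),
    ("row_1", [0.7, 0.7]),
    ("row_2", [0.0, 1.0]),
]
top = exact_search([0.9, 0.1], rows, n=2)
print([row_id for row_id, _ in top])
```

An exhaustive scan like this costs one comparison per stored row for every query, which is why a vector index is used once the table holds thousands of products.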


In [None]:
file_to_search = '998276640805c5268d5b5f7534ef05e2_bh23-luba-original-imae9mn7hpk7hwrt.jpeg'
image_search = f"/Users/samuel.matioli/work/gen-ai/multimodal/flipkart/images/{file_to_search}"
results = v_store.ann_search(n=5, vector=embed_image(image_search))

print("="*40)
print("Search:")
display(ShowImage(filename=image_search, width=200) )
print("="*40)
print("Results")
print("-"*40)

for r in results:
    # The rows were stored with the metadata keys "filename" and "local_image".
    display(ShowImage(filename=r['metadata']['local_image'], width=200) )
    print(r['body_blob'])
    print(r['metadata']['local_image'])

## Zero-Shot Classification

Reference: https://github.com/roboflow/notebooks/blob/main/notebooks/how-to-use-openai-clip-classification.ipynb

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg


In [None]:
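# Drop missing categories: extract_category returns None for rows without a category tree.
categories = df['category'].dropna().unique()
# Tokenize on the same device as the model.
classification_tokens = clip.tokenize(list(categories), truncate=True).to(device)

with torch.no_grad():
    text_features = model.encode_text(classification_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)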

In [None]:
#file_to_search = '946ec2dfec8ae7834e113c131c598d36_pbtmint4-mark-home-pbtmint4-original-imaefcqanczgzmmz.jpeg'
file_to_search = '998276640805c5268d5b5f7534ef05e2_bh23-luba-original-imae9mn7hpk7hwrt.jpeg'
image_path = f"/Users/samuel.matioli/work/gen-ai/multimodal/flipkart/images/{file_to_search}"
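with torch.no_grad():
    image_features = model.encode_image(transform(Image.open(image_path)).unsqueeze(0).to(device)).float()
    # Normalize so the dot product below is a cosine similarity, as for text_features.
    image_features /= image_features.norm(dim=-1, keepdim=True)

text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)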
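
The scoring used here (cosine similarities scaled by 100, then a softmax) can be reproduced on toy data. In this minimal sketch the label vectors, the image vector and the helper names are made up, and all vectors are assumed to be unit length so that the dot product is a cosine similarity.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def classify(image_vec, label_vecs, scale=100.0):
    """Rank labels by softmax(scale * dot(image, label)); inputs are unit vectors."""
    labels = list(label_vecs)
    logits = [scale * sum(x * y for x, y in zip(image_vec, label_vecs[label]))
              for label in labels]
    probs = softmax(logits)
    return sorted(zip(labels, probs), key=lambda pair: pair[1], reverse=True)

# Toy unit vectors standing in for the category text embeddings.
label_vecs = {
    "Watches": [1.0, 0.0],
    "Furniture": [0.0, 1.0],
    "Footwear": [0.6, 0.8],
}
ranked = classify([0.8, 0.6], label_vecs)
print(ranked[0][0])
```

The factor of 100 turns small differences in cosine similarity into large differences in probability, so the best label usually takes almost all of the probability mass.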

In [None]:
print("Classification")
print("="*40)
display(ShowImage(filename=image_path, width=200) )
print("="*40)

plt.grid()
y = np.arange(top_probs.shape[-1])
plt.barh(y, top_probs.detach().numpy()[0])
plt.gca().invert_yaxis()
plt.gca().set_axisbelow(True)
plt.yticks(y, [categories[index] for index in top_labels[0].numpy()])
plt.xlabel("probability")
plt.show()
