<a href="https://colab.research.google.com/github/rajivgaba/semantic-spotter-langchain/blob/main/semantic_spotter_final.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## ✨ Problem Statement ✨

*Fashion Search AI* 👗👖👕👟👠 : Create a generative search system capable of searching a plethora of product descriptions to find and recommend appropriate choices against a user query.

**Author:** Rajiv Gaba <br>
***LinkedIn:*** https://www.linkedin.com/in/rajiv-gaba/ <br>
***GitHub:*** https://github.com/rajivgaba <br><br>

### Install Dependencies

This cell installs all the necessary Python packages using `pip`. The `-q` flag keeps pip's output quiet and `-U` upgrades any package that is already installed. The required libraries are:

- `langchain-community`: Provides integrations with various external resources.
- `langchain_huggingface`: Enables the use of Hugging Face models within LangChain.
- `sentence-transformers`: For generating sentence embeddings.
- `langchain-perplexity`: Integration with the Perplexity AI API.
- `faiss-cpu`: A library for efficient similarity search and clustering of dense vectors.
- `gradio`: For creating a user interface for the application.
- `kaggle`: To interact with the Kaggle API for dataset download.

In [None]:
# install required packages

! pip install -qU langchain-community
! pip install -qU langchain_huggingface
! pip install -qU sentence-transformers
! pip install -qU langchain-perplexity
! pip install -qU faiss-cpu
! pip install -qU gradio
! pip install -qU kaggle

### Filter Warnings

This cell imports the `warnings` module and sets a filter to ignore all warnings. This is often done to keep the output clean, but it's important to be aware that this might hide potentially useful information about issues in the code.

In [None]:
# Filter warnings

import warnings
warnings.filterwarnings('ignore')
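A global `ignore` filter also hides deprecation notices, for example from LangChain modules that have moved. A narrower alternative, sketched below with only the standard library, suppresses warnings just inside a `with` block; `noisy_call` is a hypothetical stand-in for any call whose warnings you want to silence.

```python
import warnings


def noisy_call():
    # Hypothetical stand-in for a library call that emits a DeprecationWarning
    warnings.warn("this API is deprecated", DeprecationWarning)
    return "done"


# Suppress warnings only for the duration of this block;
# the previous filter configuration is restored on exit.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    result = noisy_call()

# A separate recording block shows the call still emits a warning when it is not ignored
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    noisy_call()

print(result, len(caught))
```

Scoping the filter this way keeps the notebook output clean where needed without hiding warnings from the rest of the pipeline.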

### Import Libraries

This cell imports all the necessary libraries for the RAG pipeline and Gradio interface. These include:

- `os`, `glob`, `Path`: For interacting with the file system.
- `PyMuPDFLoader`, `CSVLoader`, `TextLoader`: Loaders from `langchain-community` for reading different file types.
- `RecursiveCharacterTextSplitter`: For splitting text into smaller chunks.
- `HuggingFaceEmbeddings`: For generating embeddings using Hugging Face models.
- `FAISS`: For creating and loading a FAISS vector store.
- `CrossEncoder`, `util`: From `sentence_transformers` for potential use in reranking (although `HuggingFaceCrossEncoder` is used later).
- `ChatPromptTemplate`, `MessagesPlaceholder`: For creating chat prompts.
- `ChatPerplexity`: For interacting with the Perplexity AI chat model.
- `ContextualCompressionRetriever`: For compressing and reranking retrieved documents.
- `ConversationBufferMemory`: For managing conversation history.
- `CrossEncoderReranker`: For reranking documents using a cross-encoder model.
- `HuggingFaceCrossEncoder`: A specific cross-encoder implementation from Hugging Face.
- `HumanMessage`, `AIMessage`: For representing messages in a conversation.
- `gradio`: For building the user interface.
- `PIL.Image`: For handling images.

In [None]:
# import libraries

import os, glob
from pathlib import Path
from langchain_community.document_loaders import PyMuPDFLoader, CSVLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from sentence_transformers import CrossEncoder, util
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_perplexity import ChatPerplexity
from langchain.retrievers import ContextualCompressionRetriever
from langchain.memory import ConversationBufferMemory
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
from langchain_core.messages import HumanMessage, AIMessage
import gradio as gr
from PIL import Image



### Set LLM Keys

This cell sets the API key for the Perplexity AI model based on the execution environment (Kaggle, Colab, or Local). It retrieves the key from environment variables or secrets managers.

In [None]:
# Set LLM keys from secret manager / local environment

if os.path.exists('/kaggle'):
    platform = 'Kaggle'
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    print("Using kaggle secrets to get keys")
    os.environ['PERPLEXITY_API_KEY'] = user_secrets.get_secret("PPLX_API_KEY_2")
elif os.path.exists('/content'):
    platform = 'Colab'
    from google.colab import userdata
    print("Using Google colab secrets to get keys")
    os.environ['PERPLEXITY_API_KEY'] = userdata.get('PPLX_API_KEY_2')
else:
    platform = 'Local'
    import dotenv  # provided by the python-dotenv package, which the install cell above does not install
    dotenv.load_dotenv()
    print("Using local env secrets to get keys")
    os.environ['PERPLEXITY_API_KEY'] = os.getenv('PPLX_API_KEY')

Using kaggle secrets to get keys
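In the local branch, a missing `PPLX_API_KEY` makes `os.getenv` return `None`, and assigning `None` to `os.environ` fails with a `TypeError` that does not name the missing key. A small helper like the hypothetical `require_env` below (a sketch, not part of the original notebook) fails earlier with a clearer message.

```python
import os


def require_env(name: str) -> str:
    """Return the environment variable `name`, failing early if it is unset or empty."""
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(
            f"Environment variable {name!r} is not set; add it to your secrets or .env file"
        )
    return value


# Demo with a hypothetical variable set only for this example
os.environ["DEMO_API_KEY"] = "abc123"
print(require_env("DEMO_API_KEY"))
```

In the local branch this would read `os.environ['PERPLEXITY_API_KEY'] = require_env('PPLX_API_KEY')`.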


### Define Configuration

This cell defines a Python dictionary `config` that holds various configuration parameters for the application, such as data paths, chunk sizes, model names, and API keys.

In [None]:
# Define a base configuration

config = {
    'data_path' : '/kaggle/input/myntra-fashion-product-dataset/',
    'images_path' : '/kaggle/input/myntra-fashion-product-dataset/images/',
    'chunk_size' : 512,
    'chunk_overlap' : 80,
    'vector_store_name' : "faiss_myntra_db",
    'embedding_model' : 'all-MiniLM-L6-v2',
    'refresh_vector_store' : 'N',
    'PPLX_API_KEY' : os.getenv('PERPLEXITY_API_KEY'),
    'domain' : 'fashion',
    'chat_model' : "sonar-pro",
    'rerank_model' : 'BAAI/bge-reranker-base',
    'platform' : platform
}

### Download Dataset

This cell downloads the Myntra fashion product dataset from Kaggle using the Kaggle API if the notebook is not being run on the Kaggle platform.

In [None]:
# Download dataset if not running on kaggle notebook

if config['platform'] != 'Kaggle':
    # Download dataset from kaggle using kaggle API
    from kaggle.api.kaggle_api_extended import KaggleApi

    api = KaggleApi()
    api.authenticate()
    # api.dataset_download_files('promptcloud/myntra-e-commerce-product-data-november-2023', path='./data/', unzip=True)
    # api.dataset_download_files("ronakbokaria/myntra-products-dataset", path='./data/', unzip=True)
    api.dataset_download_files("djagatiya/myntra-fashion-product-dataset", path='./data/', unzip=True)

### Add Metadata to Documents

This function `add_metadata_to_documents` takes a list of documents and extracts specific information (image URL, product ID, name, category, price, color, brand, rating count, average rating) from the document content to add as metadata. This enriched metadata can be useful for filtering and displaying information later.

In [None]:
# Define a function that will enhance the metadata of the documents

def add_metadata_to_documents(documents):
    """
    Adds image URL and other attributes from page content to the metadata of each document.

    CSVLoader renders each CSV row as one "column: value" line per column, so each
    attribute is read from its line position (the positions below follow the column
    order of the Myntra CSV).
    """
    # metadata key -> line index of the attribute in page_content
    field_lines = {
        'p_id': 0, 'product_name': 1, 'product_category': 2, 'price': 3, 'color': 4,
        'brand': 5, 'image_url': 6, 'rating_count': 7, 'avg_rating': 8,
    }
    for doc in documents:
        lines = doc.page_content.split("\n")
        try:
            # Keep everything after the first ": " so multi-word values (e.g. names) are not cut to one word
            values = {key: lines[idx].split(": ", 1)[1] for key, idx in field_lines.items()}
            doc.metadata.update(values)
            doc.metadata['image_path_local'] = config['images_path'] + values['p_id'] + ".jpg"
        except (IndexError, AttributeError):
            # Row is missing fields or has an unexpected format: set every attribute to None
            for key in field_lines:
                doc.metadata[key] = None
            doc.metadata['image_path_local'] = None
    return documents
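`CSVLoader` turns each CSV row into a `Document` whose `page_content` holds one `column: value` line per column. The sketch below uses a made-up row (hypothetical values, not taken from the dataset) to show why splitting on the first `": "` keeps multi-word values intact, while `split(' ')[1]` keeps only the first word. `parse_row` is a hypothetical helper.

```python
def parse_row(page_content: str) -> dict:
    """Parse 'column: value' lines (the CSVLoader row format) into a dict."""
    fields = {}
    for line in page_content.split("\n"):
        key, sep, value = line.partition(": ")
        if sep:  # skip lines that do not contain a 'column: value' pair
            fields[key] = value
    return fields


# Hypothetical row; the values are made up for illustration
row = "p_id: 1001\nname: Floral Wrap Midi Dress\nprice: 1499.0"

print(row.split("\n")[1].split(' ')[1])  # first word only
print(parse_row(row)["name"])            # full product name
```

Looking fields up by column name, as `parse_row` does, would also keep working if the CSV's column order changed, whereas line positions would silently pick the wrong field.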

### Get Data Chunks

This function `get_data_chunks` reads files from a specified folder path, loads their content using appropriate loaders (PDF, CSV, or Text), adds metadata using the `add_metadata_to_documents` function, and then splits the documents into smaller chunks using `RecursiveCharacterTextSplitter`.

In [None]:
def get_data_chunks(folder_path):
    """
    Create chunks from the files in the dataset directory `folder_path`.
    A loader is chosen for each file based on its type (PDF, CSV or text),
    metadata is added to the loaded documents, and the documents from all
    files are then split into chunks.
    """
    all_documents = []

    # Loop through all files in the folder (subfolders such as images/ are skipped)
    for file_path in folder_path.iterdir():
        suffix = file_path.suffix.lower()
        if suffix == ".pdf":
            loader = PyMuPDFLoader(str(file_path))
        elif suffix == ".csv":
            loader = CSVLoader(str(file_path))
        elif suffix == ".txt":
            loader = TextLoader(str(file_path))
        else:
            continue  # Skip unsupported file types and directories

        # Load, enrich with metadata and append documents
        documents = loader.load()
        all_documents.extend(add_metadata_to_documents(documents))

    # Chunking/splitting of the documents from all files, not just the last one loaded
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=config['chunk_size'],
        chunk_overlap=config['chunk_overlap'],
        strip_whitespace=True,
        separators=["\n\n", "\n", " ", ""]
    )
    text_chunks = text_splitter.split_documents(all_documents)
    return text_chunks
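Roughly, `RecursiveCharacterTextSplitter` tries the separators in order (`"\n\n"`, `"\n"`, `" "`, `""`) to break the text into pieces, then merges those pieces into chunks of at most `chunk_size` characters, carrying over up to `chunk_overlap` characters between neighbouring chunks. The simplified sketch below ignores separators and just slides a fixed window, which is enough to see what the two numbers control; it is not LangChain's algorithm, and `sliding_window_chunks` is a hypothetical helper.

```python
def sliding_window_chunks(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    """Split text into fixed-size character windows; consecutive windows share chunk_overlap characters."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be >= 0 and smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reaches the end of the text
    return chunks


chunks = sliding_window_chunks("abcdefghij", chunk_size=4, chunk_overlap=1)
print(chunks)  # ['abcd', 'defg', 'ghij']
```

With the notebook's settings (`chunk_size=512`, `chunk_overlap=80`) the window advances 432 characters at a time, so text near a chunk boundary appears in both neighbouring chunks.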

### Get Embeddings Model

This function `get_embeddings_model` initializes and returns a `HuggingFaceEmbeddings` model with the specified model name and device. Embeddings are numerical representations of text that capture semantic meaning.

In [None]:
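def get_embeddings_model():
  import torch  # installed as a dependency of sentence-transformers

  # Use the GPU when one is available, otherwise fall back to the CPU
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  embedding_model = HuggingFaceEmbeddings(
      model_name=config['embedding_model'],
      show_progress=True,
      multi_process=True,
      model_kwargs={'device': device}
  )
  return embedding_model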
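Semantic search compares these vectors, commonly with cosine similarity: texts with similar meaning get vectors pointing in similar directions. The toy 3-dimensional vectors below are made up (real `all-MiniLM-L6-v2` embeddings have 384 dimensions) and only illustrate the arithmetic.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """cos(a, b) = (a . b) / (|a| * |b|); 1.0 means same direction, 0.0 means orthogonal."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


# Made-up vectors standing in for embeddings of a query and two products
query = [1.0, 0.0, 1.0]
party_dress = [0.9, 0.1, 0.8]
running_shoe = [0.0, 1.0, 0.1]

print(round(cosine_similarity(query, party_dress), 3))
print(round(cosine_similarity(query, running_shoe), 3))
```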

### Create Vector Store

This function `create_vector_store` creates or loads a FAISS vector store. If `refresh_vector_store` is 'Y' or the vector store doesn't exist, it creates a new one from the provided text chunks and embedding model and saves it locally. Otherwise, it loads an existing vector store. FAISS allows for efficient similarity search on the embeddings.

In [None]:
def create_vector_store(text_chunks, embedding_model):
  if config['refresh_vector_store'] == 'Y' or not os.path.exists(config['vector_store_name']):
      vector_store = FAISS.from_documents(text_chunks, embedding_model)
      vector_store.save_local(config['vector_store_name'])
  else:
      vector_store = FAISS.load_local(config['vector_store_name'], embedding_model, allow_dangerous_deserialization=True)
  return vector_store
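The "rebuild on request or when missing, otherwise load" logic can be sketched with the standard library. `load_or_build` and the JSON file are hypothetical stand-ins for `FAISS.from_documents`, `save_local` and `load_local`. One design difference worth knowing: FAISS's `load_local` unpickles part of the saved store, which is why it requires `allow_dangerous_deserialization=True`, whereas loading JSON cannot execute code.

```python
import json
import os
import tempfile


def load_or_build(path: str, build, refresh: bool = False):
    """Build and save the data if refresh is requested or no saved copy exists; otherwise load it."""
    if refresh or not os.path.exists(path):
        data = build()
        with open(path, "w", encoding="utf-8") as f:
            json.dump(data, f)
        return data, "built"
    with open(path, encoding="utf-8") as f:
        return json.load(f), "loaded"


with tempfile.TemporaryDirectory() as tmp:
    store_path = os.path.join(tmp, "store.json")
    _, first = load_or_build(store_path, lambda: {"docs": ["a", "b"]})
    _, second = load_or_build(store_path, lambda: {"docs": ["a", "b"]})
    data, third = load_or_build(store_path, lambda: {"docs": ["c"]}, refresh=True)
    print(first, second, third)  # built loaded built
```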

### Create Chat Client

This function `create_chat_client` initializes and returns a `ChatPerplexity` object, which is used to interact with the Perplexity AI chat model. It sets the temperature, API key, and model name from the configuration.

In [None]:
def create_chat_client():
  return ChatPerplexity(
      temperature=0,
      pplx_api_key=config['PPLX_API_KEY'], # Pass the API key explicitly
      model=config['chat_model']
  )
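`temperature=0` makes generation as close to deterministic as the service allows, which suits a recommender that should stay close to the retrieved context. The sketch below shows the usual temperature-scaled softmax over made-up token scores; treating T = 0 as a plain argmax is a convention of this sketch, and it does not describe Perplexity's server-side sampling.

```python
import math


def temperature_softmax(logits: list[float], temperature: float) -> list[float]:
    """Turn raw scores into probabilities; lower temperature sharpens the distribution."""
    if temperature == 0:
        # Limit T -> 0: all probability mass on the highest-scoring token (greedy decoding)
        best = max(range(len(logits)), key=lambda i: logits[i])
        return [1.0 if i == best else 0.0 for i in range(len(logits))]
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


scores = [2.0, 1.0, 0.5]  # made-up scores for three candidate tokens
for t in (1.0, 0.5, 0.0):
    print(t, [round(p, 3) for p in temperature_softmax(scores, t)])
```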

### Get Retriever

This function `get_retriever` creates and returns a retriever from the FAISS vector store. It reads the global `vector_store` variable, so the store must be created before the function is called. The retriever fetches the documents most relevant to a user query, and `search_kwargs={'k': top_k}` sets how many of them to return.

In [None]:
def get_retriever(top_k=10):
  retriever = vector_store.as_retriever(search_kwargs={'k': top_k})
  return retriever
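The retriever returns the `k` stored chunks whose embeddings lie closest to the query embedding. By default, LangChain's `FAISS.from_documents` builds an exact (flat) L2 index, so the result should match a brute-force nearest-neighbour search like the sketch below, only much faster. The 2-D vectors and product names are made up, and `top_k_l2` is a hypothetical helper.

```python
import heapq
import math


def top_k_l2(query: list[float], vectors: dict[str, list[float]], k: int) -> list[tuple[str, float]]:
    """Exhaustive k-nearest-neighbour search: return the k ids with the smallest Euclidean distance."""
    distances = ((doc_id, math.dist(query, vec)) for doc_id, vec in vectors.items())
    return heapq.nsmallest(k, distances, key=lambda item: item[1])


# Made-up 2-D "embeddings" for four products
store = {
    "sequin party dress": [0.9, 0.8],
    "satin evening gown": [0.7, 0.9],
    "cotton kurta": [0.4, 0.1],
    "running shoes": [0.0, -0.9],
}
print(top_k_l2([1.0, 1.0], store, k=2))
```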

### Get Reranked Query Results

This function `get_reranked_query_results` takes a user query, retrieves an initial set of candidate documents (10 by default, via `get_retriever`), and then reranks them with a `HuggingFaceCrossEncoder`. Reranking improves relevance because the cross-encoder scores the query and each document together, rather than comparing separately computed embeddings. The function returns the 3 highest-scoring documents (`top_n=3` in the compressor).

In [None]:
def get_reranked_query_results(query):
    model = HuggingFaceCrossEncoder(model_name=config['rerank_model'])
    compressor = CrossEncoderReranker(model=model, top_n=3)
    compression_retriever = ContextualCompressionRetriever(
      base_compressor=compressor,
      base_retriever=get_retriever()
    )
    compressed_docs = compression_retriever.invoke(query)
    print("="*80)
    print(f"{compressed_docs}")
    print("="*80)

    return compressed_docs
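The two-stage pattern is: fetch a generous candidate set cheaply (`k=10` from FAISS), then score each (query, document) pair jointly and keep the best `top_n`. In the sketch below, a word-overlap score is a deliberately crude, hypothetical stand-in for the `BAAI/bge-reranker-base` cross-encoder, which reads query and document together through a transformer; only the rerank-and-truncate structure carries over.

```python
def overlap_score(query: str, doc: str) -> float:
    """Toy pair scorer (stand-in for a cross-encoder): fraction of query words found in the document."""
    q_words = set(query.lower().split())
    d_words = set(doc.lower().split())
    return len(q_words & d_words) / len(q_words) if q_words else 0.0


def rerank(query: str, candidates: list[str], top_n: int = 3) -> list[str]:
    """Score every (query, candidate) pair and keep the top_n highest-scoring candidates."""
    scored = sorted(candidates, key=lambda doc: overlap_score(query, doc), reverse=True)
    return scored[:top_n]


# Candidates as a first-stage vector search might return them (made-up text)
candidates = [
    "men running shoes",
    "women sequin party dress",
    "women cotton kurta",
    "party wear dress for women",
]
print(rerank("party dresses for women", candidates, top_n=2))
```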

### Generate LLM Response

This function `generate_llm_response` takes the user query and the reranked search results as input. It creates a `ChatPerplexity` client, defines a system message that provides context to the LLM based on the retrieved documents, and then uses a `ChatPromptTemplate` to format the prompt for the LLM. Finally, it invokes the LLM with the formatted prompt and returns the generated response.

In [None]:
def generate_llm_response(query, results):
    llm = create_chat_client()
    system_message = f"""
    You are a helpful AI assistant in the fashion domain and an expert at looking through the given documents and finding relevant products.
    Do not list any products outside this context.
    The context is given here:
    #####
    {results}
    #####
    If you don't know the answer, say so. Keep the conversation flowing.

    #####
    Extract the brand, price, avg_rating, rating_count, image_url and image_path_local from the metadata.
    #####
    """
    prompt_template = ChatPromptTemplate.from_messages([
        ("system", "{system_message}"),
        ("human", "{query}")
    ])

    chain = prompt_template | llm
    llm_response = chain.invoke(
        {"query" : query,
        "system_message" : system_message}
    )
    return llm_response
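Passing the context through the `{system_message}` variable, rather than writing the retrieved text straight into the template string, matters because the context contains Python dict reprs of the metadata, whose `{` and `}` would otherwise be read as template placeholders. LangChain's default f-string templates follow essentially the same placeholder rules as Python's `str.format`, so the standard-library sketch below reproduces the failure and the fix with a made-up context string.

```python
# Made-up context in the same shape as str(doc.metadata) + page_content
context = "{'brand': 'Roadster', 'price': '1499.0'}\nname: Floral Wrap Midi Dress"

template_with_context_inlined = f"Context:\n{context}"

# Formatting the inlined text treats "{'brand'...}" as a placeholder and fails
try:
    template_with_context_inlined.format()
    inlined_ok = True
except (KeyError, ValueError, IndexError):
    inlined_ok = False

# Substituting the context as a value leaves its braces untouched
rendered = "{system_message}".format(system_message=f"Context:\n{context}")

print(inlined_ok)
print(rendered)
```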

### RAG Pipeline

This function `rag_pipeline` orchestrates the RAG (Retrieval Augmented Generation) process. It takes user input, retrieves and reranks relevant documents using `get_reranked_query_results`, formats the retrieved documents as context, and then generates a response using `generate_llm_response`.

In [None]:
def rag_pipeline(user_input):
    # Get reranked top results of the user query
    retrieved_documents = get_reranked_query_results(user_input)
    # formatted_context = "\n\n".join(doc.page_content for doc in retrieved_documents)
    formatted_context = "\n\n".join( (str(doc.metadata) + doc.page_content) for doc in retrieved_documents)
    print("*"*80)
    print(f"\n\n\n {formatted_context} \n\n\n")
    print("*"*80)
    answer = generate_llm_response(user_input, formatted_context)
    return answer
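`str(doc.metadata) + doc.page_content` glues a raw dict repr onto the text with no separator. One possible alternative, sketched below, renders only the fields the prompt asks for as labelled lines. `format_context`, the `Doc` stand-in class and the field list are hypothetical; in the notebook you would pass the LangChain `Document` objects, which expose the same `page_content` and `metadata` attributes.

```python
from dataclasses import dataclass, field


@dataclass
class Doc:
    """Minimal stand-in for a LangChain Document (page_content + metadata)."""
    page_content: str
    metadata: dict = field(default_factory=dict)


PROMPT_FIELDS = ["brand", "price", "avg_rating", "rating_count", "image_url"]


def format_context(docs) -> str:
    """Render each document as a numbered header, labelled metadata lines, then its text."""
    blocks = []
    for i, doc in enumerate(docs, start=1):
        meta_lines = [
            f"{key}: {doc.metadata.get(key)}"
            for key in PROMPT_FIELDS
            if doc.metadata.get(key) is not None  # skip attributes that could not be parsed
        ]
        blocks.append("\n".join([f"[Product {i}]", *meta_lines, doc.page_content]))
    return "\n\n".join(blocks)


docs = [Doc("name: Floral Wrap Midi Dress", {"brand": "Roadster", "price": "1499.0", "image_url": None})]
print(format_context(docs))
```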

### Get Answer

This function `get_answer` is a simple wrapper around the `rag_pipeline` function, taking a question as input and returning the final answer generated by the RAG pipeline.

In [None]:
def get_answer(question):
    # Call the RAG pipeline to get the answer based on the user's question.
    final_answer = rag_pipeline(question)
    return final_answer

### Start Process and Set Path

This cell sets the `folder_path` variable based on the platform the notebook is running on (Kaggle, Colab, or Local). This path is used to locate the dataset.

In [None]:
# Start the process and set path for dataset

chunked_data = []

if (config['platform']).lower() == 'kaggle':
    folder_path = Path("/kaggle/input/myntra-fashion-product-dataset")
elif (config['platform']).lower() == 'colab':
    folder_path = Path("/content/data/")
else:
    # The download cell saves the dataset to ./data/ when not running on Kaggle
    folder_path = Path("./data/")

print(folder_path)

/kaggle/input/myntra-fashion-product-dataset


### Create Vector Store

This cell loads the embedding model and then either generates chunks from the dataset and creates a new FAISS vector store (when `refresh_vector_store` is 'Y' or no saved store exists yet) or loads the existing one. This step prepares the vector store for efficient document retrieval.

In [None]:
# Load the embedding model once; it is needed both to build and to load the store
embedding_model = get_embeddings_model()

if config['refresh_vector_store'] == 'Y' or not os.path.exists(config['vector_store_name']):
    # Generate chunks from the dataset; only needed when the store is built
    chunked_data = get_data_chunks(folder_path)

# Create embeddings and save them to a new vector store, or load the saved one
vector_store = create_vector_store(chunked_data, embedding_model)


### User Input Example

This cell sets an example `user_input` variable, which can be used for testing the RAG pipeline or the Gradio interface.

In [None]:
user_input = "party dresses for women"

### Gradio Chat Interface

This function `gradio_chat_interface` is designed to be used with Gradio. It takes the user input, retrieves and reranks documents with `get_reranked_query_results`, passes them to `generate_llm_response`, decodes HTML entities in the response text, and loads the local product images referenced in the metadata of the top 3 retrieved documents, using `None` for any image that is missing or cannot be opened.

In [None]:
import html

def gradio_chat_interface(user_input):
    results = get_reranked_query_results(user_input)
    response = generate_llm_response(user_input, results)

    # Decode HTML entities such as &quot; and &amp; in the model output
    cleaned_text = html.unescape(response.content)

    # Get product images
    images = []
    for doc in results[:3]:
        image_path = doc.metadata.get('image_path_local')
        if image_path and os.path.exists(image_path):
            try:
                images.append(Image.open(image_path))
            except OSError:  # unreadable or non-image file (PIL raises an OSError subclass)
                images.append(None)
        else:
            images.append(None)

    while len(images) < 3:
        images.append(None)

    return cleaned_text, images[0], images[1], images[2]
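`html.unescape` already converts named and numeric character references such as `&quot;`, `&amp;` and `&#39;`, so no extra `replace` calls are needed before or after it. A quick check with a made-up response string:

```python
import html

# Made-up model output containing escaped characters
raw = "Try the &quot;Floral Wrap&quot; dress &amp; Roadster&#39;s heels"
print(html.unescape(raw))
```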



### Create Gradio Interface

This cell defines and launches the Gradio user interface. It creates a simple interface with a textbox for user input, a button to trigger the search, a textbox to display the LLM's recommendations, and image components to display product images. The `gradio_chat_interface` function is linked to the button click and textbox submission events.

In [None]:
# Create Gradio interface
with gr.Blocks(title="Semantic Spotter") as demo:
    gr.Markdown("# 🛍️ Semantic Spotter - Fashion Search")

    with gr.Row():
        user_input = gr.Textbox(label="Search Query", placeholder="party dresses for women", scale=3)
        submit_btn = gr.Button("Search", variant="primary", scale=1)

    response_text = gr.Textbox(label="Recommendations", lines=6, interactive=False)

    with gr.Row():
        img1 = gr.Image(label="Product 1", height=200)
        img2 = gr.Image(label="Product 2", height=200)
        img3 = gr.Image(label="Product 3", height=200)

    submit_btn.click(gradio_chat_interface, [user_input], [response_text, img1, img2, img3])
    user_input.submit(gradio_chat_interface, [user_input], [response_text, img1, img2, img3])

demo.launch(share=True)

* Running on local URL:  http://127.0.0.1:7860
* Running on public URL: https://9c1dfbf1c13f5580a6.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


