<a href="https://colab.research.google.com/github/rajivgaba/semantic-spotter-langchain/blob/main/semantic_spotter_draft.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# install required packages
# NOTE(review): gradio and kaggle are imported in the next cell but are not
# installed here; this assumes the runtime already provides them — confirm.

! pip install -qU langchain-community pymupdf
! pip install -qU langchain_huggingface
! pip install -qU sentence-transformers
! pip install -qU langchain-perplexity
! pip install -qU faiss-cpu

In [None]:
# import libraries

import functools
import os, glob
from pathlib import Path
from langchain_community.document_loaders import PyMuPDFLoader, CSVLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from sentence_transformers import CrossEncoder, util
from langchain_core.prompts import ChatPromptTemplate
from langchain_perplexity import ChatPerplexity
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import CrossEncoderReranker
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
import gradio as gr
import kaggle
from kaggle.api.kaggle_api_extended import KaggleApi
from google.colab import userdata

In [None]:
# Set the key for LLM API
#
# Copy the Perplexity key from the Colab secret store into the process
# environment so the config cell can read it with os.getenv.
# NOTE(review): the secret is named 'PREPLEXITY_API_KEY' (sic); it must match
# the name used in Colab's secrets panel, so it is intentionally left as-is.
_pplx_secret_name = 'PREPLEXITY_API_KEY'
os.environ[_pplx_secret_name] = userdata.get(_pplx_secret_name)

In [None]:
# Set base configuration values
#
# Central settings read by the helper functions below.

config = {
    # Local data folder on the author's machine. NOTE(review): the __main__
    # cell reads the Kaggle download from /content/data/ instead of this path.
    'data_path' : '/Users/RajivGaba/aiml_projects/Semantic Spotter/Data/',
    # Text-splitter settings: maximum chunk size and overlap between chunks.
    'chunk_size' : 512,
    'chunk_overlap' : 80,
    # Folder the FAISS index is saved to / loaded from.
    'vector_store_name' : "faiss_index",
    'embedding_model' : 'all-MiniLM-L6-v2',
    # 'Y' forces the FAISS index to be rebuilt; any other value loads the
    # saved index when one exists.
    'refresh_vector_store' : 'Y',
    # Populated from the environment variable set in the previous cell.
    'PPLX_API_KEY' : os.getenv('PREPLEXITY_API_KEY'),
    'domain' : 'fashion',
    'chat_model' : "sonar-pro",
    # Cross-encoder used to rerank retrieved chunks.
    'rerank_model' : 'BAAI/bge-reranker-base',
    # Cross-encoder read by get_cross_encoder_score(); that helper looks this
    # key up directly, so it must be present. Same model as the reranker.
    'cross_encoder_model' : 'BAAI/bge-reranker-base',
}

In [None]:
# Download dataset from kaggle using kaggle API
#
# Alternative dataset that was tried earlier:
#   promptcloud/myntra-e-commerce-product-data-november-2023
# NOTE(review): files are unzipped into ./data/ relative to the working
# directory; the __main__ cell reads /content/data/, which matches only when
# the notebook runs from /content (the usual Colab working dir) — confirm.

api = KaggleApi()
api.authenticate()

kaggle_dataset = "ronakbokaria/myntra-products-dataset"
api.dataset_download_files(kaggle_dataset, path='./data/', unzip=True)

In [None]:
# Define reusable functions

def get_data_chunks(folder_path):
  """
  Load every supported file in ``folder_path`` and split it into text chunks.

  Files are dispatched to a LangChain loader by extension (case-insensitive):
  ``.pdf`` -> PyMuPDFLoader, ``.csv`` -> CSVLoader, ``.txt`` -> TextLoader.
  Other files and sub-directories are skipped. Files are processed in sorted
  order so repeated runs produce chunks (and therefore vector-store entries)
  in the same order.

  Args:
    folder_path: directory containing the dataset, as a ``str`` or
      ``pathlib.Path``.

  Returns:
    List of chunked documents, split with config['chunk_size'] and
    config['chunk_overlap']. Empty if no supported files were found.
  """
  # Map each supported (lower-cased) file extension to its loader class.
  loaders_by_suffix = {
      ".pdf": PyMuPDFLoader,
      ".csv": CSVLoader,
      ".txt": TextLoader,
  }

  all_documents = []

  # Path(...) lets callers pass plain strings; sorted() makes iteration
  # order deterministic (iterdir() order is filesystem-dependent).
  for file_path in sorted(Path(folder_path).iterdir()):
    loader_cls = loaders_by_suffix.get(file_path.suffix.lower())
    if loader_cls is None or not file_path.is_file():
      continue  # unsupported type, or a directory that happens to match
    all_documents.extend(loader_cls(str(file_path)).load())

  # chunking/splitting
  text_splitter = RecursiveCharacterTextSplitter(
      chunk_size=config['chunk_size'],
      chunk_overlap=config['chunk_overlap'],
      strip_whitespace=True,
      separators=["\n\n", "\n", " ", ""]
  )
  return text_splitter.split_documents(documents=all_documents)

def get_embeddings_model():
  """
  Build the sentence-embedding model named by config['embedding_model'].

  The model runs on the GPU when CUDA is available and on the CPU otherwise,
  so the notebook also works on CPU-only runtimes instead of failing on a
  hard-coded 'cuda' device.

  Returns:
    A HuggingFaceEmbeddings instance with a progress bar and multi-process
    encoding enabled.
  """
  # torch is installed as a dependency of sentence-transformers.
  import torch

  device = 'cuda' if torch.cuda.is_available() else 'cpu'
  return HuggingFaceEmbeddings(
      model_name=config['embedding_model'],
      show_progress=True,
      multi_process=True,
      model_kwargs={'device': device},
  )

def create_vector_store(text_chunks, embedding_model):
  """
  Return a FAISS vector store, either freshly built or loaded from disk.

  A new index is built from ``text_chunks`` and saved to
  config['vector_store_name'] when config['refresh_vector_store'] == 'Y' or
  no saved index exists. Otherwise the saved index is loaded and
  ``text_chunks`` is ignored.

  Args:
    text_chunks: chunked documents to index (needed only when building).
    embedding_model: embedding model used to embed chunks and queries.

  Returns:
    The FAISS vector store.

  Raises:
    ValueError: if a new index must be built but ``text_chunks`` is empty.
  """
  store_path = config['vector_store_name']
  must_build = config['refresh_vector_store'] == 'Y' or not os.path.exists(store_path)

  if must_build:
    if not text_chunks:
      # Fail with an actionable message rather than an obscure error from
      # FAISS when asked to index nothing.
      raise ValueError(
          f"Cannot build vector store '{store_path}': no text chunks were "
          "provided. Chunk the dataset before building the index."
      )
    vector_store = FAISS.from_documents(text_chunks, embedding_model)
    vector_store.save_local(store_path)
    return vector_store

  # SECURITY: allow_dangerous_deserialization=True permits pickle-based
  # loading of the saved docstore. Only load indexes this notebook created
  # itself, never an index obtained from an untrusted source.
  return FAISS.load_local(store_path, embedding_model, allow_dangerous_deserialization=True)

def get_cross_encoder_score(query, results):
  """
  Score each retrieved document against ``query`` with a cross-encoder.

  Each score is printed (a debugging aid) and the scores are also returned.

  Args:
    query: the user's query text.
    results: retrieved documents; each must expose ``page_content``.

  Returns:
    List of cross-encoder scores, in the same order as ``results``.
    An empty list when ``results`` is empty.
  """
  if not results:
    return []

  # config may not define 'cross_encoder_model'; fall back to the reranker
  # model instead of raising KeyError.
  model_name = config.get('cross_encoder_model', config['rerank_model'])
  cross_encoder = CrossEncoder(model_name)

  # Score all (query, passage) pairs in one batched call rather than one
  # predict() call per document.
  scores = cross_encoder.predict([(query, doc.page_content) for doc in results])
  for score in scores:
    print(score)
  return list(scores)

def create_chat_client():
  """Return a Perplexity chat model (temperature 0) configured from ``config``."""
  client_settings = dict(
      temperature=0,
      pplx_api_key=config['PPLX_API_KEY'],
      model=config['chat_model'],
  )
  return ChatPerplexity(**client_settings)

def get_retriever(top_k=5, store=None):
  """
  Return a retriever that fetches the ``top_k`` most similar chunks.

  Args:
    top_k: number of documents returned per query.
    store: vector store to wrap. When omitted, falls back to the module-level
      ``vector_store`` created in the ``__main__`` cell, as before.

  Returns:
    The retriever produced by ``store.as_retriever``.
  """
  if store is None:
    # Implicit dependency on the notebook-level global; pass ``store`` to
    # avoid it.
    store = vector_store
  return store.as_retriever(search_kwargs={'k': top_k})

def get_llm_response(query, results, domain=None):
  """
  Answer ``query`` with the Perplexity chat model, grounded in ``results``.

  Args:
    query: the user's question, sent as the human message.
    results: retrieved context, inserted verbatim into the system prompt.
    domain: domain the assistant specialises in. When None, falls back to
      config['domain']. (Previously this argument was silently ignored.)

  Returns:
    The text content of the model's reply.
  """
  system = """
  You are a helpful assistant in {domain} domain.
  Using the information contained in the context, give a comprehensive answer to the question.
  Respond only to the question asked, response should be concise and relevant to the question.
  Provide the number of the source document when relevant.
  If the answer cannot be deduced from the context, do not give an answer.
  If there are image files in the given context, show the images to the user.

  #####
  Here is the context: {context}
  #####

  #####
  Here is an example conversation:
  User: dresses for beach vacation
  Assistant: Sure. Beach vacations are all about being free and close to nature. Here are suggestions: floral yellow shirt from brand X <image>, midi dress with fine blue prints from brand Y <image>
  User: how about dresses for women in yellow
  Assistant: Sure, here you go: midi dress from brand Y <image>, long floral gown from brand Z <image>
  #####

  You ask the follow up question to the user for relevant pairing suggestions like hat for beach, muffler for cold weather etc.
  """
  human = "{query}"

  prompt = ChatPromptTemplate.from_messages([("system", system), ("human", human)])
  chain = prompt | create_chat_client()

  response = chain.invoke(
      {
          "context" : results,
          # Honour the caller-supplied domain; config is only the fallback.
          "domain" : domain if domain is not None else config['domain'],
          "query": query,
      }
  )
  return response.content

@functools.lru_cache(maxsize=4)
def _get_reranker_model(model_name):
  """Load the HuggingFace cross-encoder once per model name and reuse it."""
  return HuggingFaceCrossEncoder(model_name=model_name)

def get_reranked_query_results(query, fetch_k=20, top_n=5):
  """
  Retrieve candidate chunks for ``query`` and rerank them with a cross-encoder.

  Args:
    query: the user's search text.
    fetch_k: number of candidates the vector-store retriever returns before
      reranking. It should exceed ``top_n``. If the two were equal, the
      reranker could only reorder the candidates, never filter them.
    top_n: number of top-scoring documents kept after reranking.

  Returns:
    Up to ``top_n`` documents selected by the cross-encoder.
  """
  # The model is cached, so it is not reloaded from disk on every chat turn.
  model = _get_reranker_model(config['rerank_model'])

  compressor = CrossEncoderReranker(model=model, top_n=top_n)

  compression_retriever = ContextualCompressionRetriever(
      base_compressor=compressor,
      # Never fetch fewer candidates than we intend to keep.
      base_retriever=get_retriever(top_k=max(fetch_k, top_n))
  )

  return compression_retriever.invoke(query)

def rag_pipeline(user_question, history):
  """
  Run retrieval-augmented generation for a single user question.

  Args:
    user_question: the user's query text.
    history: chat history passed in by the Gradio callback. It is currently
      unused, so each turn is answered independently.

  Returns:
    The LLM's answer string.
  """
  # Retrieve and rerank the chunks most relevant to the question.
  retrieved_documents = get_reranked_query_results(user_question)

  # Number each passage ("[1] ...") so the model can cite the source document
  # number, as its system prompt asks it to.
  formatted_context = "\n\n".join(
      f"[{idx}] {doc.page_content}"
      for idx, doc in enumerate(retrieved_documents, start=1)
  )

  # Call the large language model with the question and the numbered context.
  return get_llm_response(user_question, formatted_context, config['domain'])

def get_answer(user_query, history):
  """Gradio chat callback: answer ``user_query`` by running the RAG pipeline."""
  answer = rag_pipeline(user_query, history)
  return answer

In [None]:
if __name__ == "__main__":
    chunked_data = []
    # Dataset location in the Colab runtime. config['data_path'] holds a path
    # on the author's local machine and is not used here.
    folder_path = Path("/content/data/")

    # Step 1: Generate chunks from the dataset. Chunks are needed whenever the
    # index will be (re)built: on an explicit refresh, or when no saved index
    # exists yet. Otherwise an index would be built from an empty list.
    needs_build = (
        config['refresh_vector_store'] == 'Y'
        or not os.path.exists(config['vector_store_name'])
    )
    if needs_build:
        chunked_data = get_data_chunks(folder_path)

    # Step 2: Create embeddings and put them into a vector store
    embedding_model = get_embeddings_model()
    vector_store = create_vector_store(chunked_data, embedding_model)

In [None]:
# Smoke-test retrieval + reranking with a sample query; the trailing bare
# expression makes the notebook display the returned documents.
sample_docs = get_reranked_query_results(query="party dresses for women")
sample_docs

In [None]:
# Chat UI wired to the RAG pipeline; get_answer receives (message, history).
iface = gr.ChatInterface(
    get_answer,
    type = 'messages',
    # Chatbot type matches the ChatInterface type to avoid a type-mismatch override.
    chatbot=gr.Chatbot(height=500, type='messages'),
    textbox=gr.Textbox(placeholder="Help me find a formal shirt..", container=False, scale=7),
    title="ClosetAI - Fashion Studio",
    description="Ask ClosetAI anything about fashion products on Myntra",
    theme="glass",
    # No examples are configured, so there is nothing to cache. If examples are
    # added, caching would presumably run the (paid) LLM on each one ahead of time.
    cache_examples=False,
)
iface.launch(share=True, debug=False)

In [None]:
# Shut down the Gradio app started by iface.launch() in the previous cell.
iface.close()