In [106]:
# Setup the Chroma Storage
from langchain_community.vectorstores import Chroma
from langchain_experimental.open_clip import OpenCLIPEmbeddings
# Create Chroma
vectorstore = Chroma(
    collection_name="ciam_rag_data", 
    embedding_function=OpenCLIPEmbeddings(model=None, preprocess=None, tokenizer=None)
)


In [107]:
from pathlib import Path
import json

product_path = "../training-sets/MEN_FASHION_WITH_REVIEWS"
product_list = json.loads(Path(product_path + "/raw_products.json").read_text())
products = product_list[0:1]

In [108]:
for prod in products:
    keys = [
        "reviews",
        "category",
        "price",
        "currency",
        "total_customers_that_rated",
        "overall_ratings",
        "description",
        "name",
        "id",
        "description",
        "product_asin"
    ]
    product_asin = prod['product_asin']
    print(Path(f"{product_path}/images/{product_asin}.png").exists())

True


In [121]:
# Prepare the Document for each product
from langchain_core.documents import Document

def convert_to_document(product: dict):
    product_asin = product.get('product_asin').strip()
    product_name = product.get('name').strip()
    price_currency = product.get('currency')
    product_description = product.get('description').strip()
    image_path = str(Path(f"{product_path}/images/{product_asin}.png").absolute())
    document = Document(
        id=product_asin,
        page_content=f"Product Asin: {product_asin} \n\n Name: {product_name} \n\n Description: {product_description}",
        metadata={
            "category": product.get("category"),
            "product_name": product_name,
            "price": f"{price_currency}{product.get('price')}",
            "overall_ratings": product.get("overall_ratings"),
            "product_asin": product.get("product_asin"),
            "total_customers_that_rated": product.get('total_customers_that_rated'),
            "image_path": image_path,
        }
    )
    return document


documents = list(map(convert_to_document, product_list[0:5]))


In [128]:
# Add our data into the Vector Store
doc_ids = vectorstore.add_documents(documents, ids=[doc.id for doc in documents])

document_images = [
    {
        "url": doc.metadata.get("image_path"),
        "product_asin": doc.metadata.get("product_asin"),
        "product_name": doc.metadata.get("name"),
    } for doc in documents
]

image_doc_ids = vectorstore.add_images( uris=[ image['url'] for image in document_images ], metadata=document_images, ids=[ image_data['product_asin'] for image_data in document_images ] )
assert doc_ids == image_doc_ids

result = vectorstore.search("sneakers", "mmr")
print(result)

retriever = vectorstore.as_retriever()

[Document(metadata={'category': 'Fashion', 'image_path': '/home/solomon/Documents/projects/ciam2rag/ciam2rag_core/notebooks/../training-sets/MEN_FASHION_WITH_REVIEWS/images/B0C812K6RR.png', 'overall_ratings': 4.3, 'price': '£16.99', 'product_asin': 'B0C812K6RR', 'total_customers_that_rated': 156}, page_content="Product Asin: B0C812K6RR \n\n Name: Safety Trainers Men Women Steel Toe Cap Trainers Safety Shoes Lightweight Work Shoes Safety Boots Industrial Protective Shoes \n\n Description: Product details     Care instructions     Hand Wash Only       Sole material     Thermoplastic Elastomers       Outer material     Textile       Inner material     Textile      About this item   Breathable upper: Fashion appearance, mesh hole design very breathable, these work boots are extremely light, while the breathable mesh upper can effectively disperse heat, even in hot working conditions, your feet will stay fresh cool and odorless.   Anti-Smashing: The standard widened steel toe is able absorb

In [111]:
# Utility Functions from Langchain cookbooks

import base64
import io
from io import BytesIO

import numpy as np
from PIL import Image


def resize_base63_image(base64_string, size=(128, 128)):
    """
    Resize an image encoded as a Base64 string.
    
    Args:
    base64_string (str): Base64 string of the original image size (tuple): Desired size of the image as (width, height).
    
    Returns:
    str: Base64 string of the resized image.
    
    """
    
    # Decode the Base64 string
    img_data = base64.b64decode(base64_string)
    img = Image.open(io.BytesIO(img_data))
    
    # Resize the image
    resized_img = img.resize(size, Image.Resampling.LANCZOS)
    
    # Save the resized image to a bytes buffer
    buffered = io.BytesIO()
    resized_img.save(buffered, format=img.format)
    
    # Encode the resized image to Base64
    return base64.b64encode(buffered.getvalue()).decode("utf-8")

def is_base64(s):
    """Check if a string is Base64 encoded"""
    
    try:
        return base64.b64encode(base64.b64decode(s)) == s.encode()
    except Exception:
        return False
    
def split_image_text_types(docs):
    """Split numpy arrays images and texts"""
    
    images = []
    texts = []
    
    for doc in docs:
        converted_doc = doc.page_content # Extract Document contents
        if is_base64(converted_doc):
            # Resize image to avoid OAI server error
            images.append(
                resize_base63_image(converted_doc, size=(250, 250))
            ) # base64 encoded str
        else:
            texts.append(doc)
            
    return {"images": images, "texts": texts}

In [125]:
from IPython.display import HTML, display


def plt_img_base64(img_base64):
    # Create an HTML img tag with the base64 string as the source
    image_html = f'<img src="data:image/jpeg;base64,{img_base64}" />'
    
    # Display the image by rendering the HTML
    display(HTML(image_html))
    
    
docs = retriever.invoke("shoe", k=10)
for doc in docs:
    if is_base64(doc.page_content):
        plt_img_base64(doc.page_content)
    else:
        print(doc.metadata)

{'category': 'Fashion', 'image_path': '/home/solomon/Documents/projects/ciam2rag/ciam2rag_core/notebooks/../training-sets/MEN_FASHION_WITH_REVIEWS/images/B01A6LTSUK.png', 'overall_ratings': 4.7, 'price': '£38.49', 'product_asin': 'B01A6LTSUK', 'total_customers_that_rated': 102877}
{'category': 'Fashion', 'image_path': '/home/solomon/Documents/projects/ciam2rag/ciam2rag_core/notebooks/../training-sets/MEN_FASHION_WITH_REVIEWS/images/B01A6LTSUK.png', 'overall_ratings': 4.7, 'price': '£38.49', 'product_asin': 'B01A6LTSUK', 'total_customers_that_rated': 102877}
{'category': 'Fashion', 'image_path': '/home/solomon/Documents/projects/ciam2rag/ciam2rag_core/notebooks/../training-sets/MEN_FASHION_WITH_REVIEWS/images/B01A6LTSUK.png', 'overall_ratings': 4.7, 'price': '£38.49', 'product_asin': 'B01A6LTSUK', 'total_customers_that_rated': 102877}
{'category': 'Fashion', 'image_path': '/home/solomon/Documents/projects/ciam2rag/ciam2rag_core/notebooks/../training-sets/MEN_FASHION_WITH_REVIEWS/images/

In [113]:


from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_community.chat_models.ollama import ChatOllama
from langchain_core.prompts.chat import ChatPromptTemplate, MessagesPlaceholder


VISION_MODEL_NAME = "gpt-4o"
# VISION_MODEL_NAME = "llava"
VISION_MODEL_NAME = "llama3"

OLLAMA_IN_USE = True

In [114]:

def prompt_func(data_dict: dict):
    # Joining the context texts into a single string
    
    formatted_texts = "\n".join(data_dict["context"]["texts"])
    messages = []
    # Adding image(s) to the messages if present
    if data_dict["context"]["images"]:
        print("Adding images to the messages")
        image_message = {
            "type": "image_url",
            "image_url": {
                "url": f"{data_dict['context']['images'][0]}" if OLLAMA_IN_USE == "llava" else f"data:image/jpeg;base64, {data_dict['context']['images'][0]}"
            },
        }
        messages.append(image_message)
    
    # Adding the text message for analysis
    text_message = {
        "type": "text",
        "text": (
            "You answer should be in the format "
            "Product ASIN\nProduct Name\nYour Thought\nPrice\nOther Details"
            "You work in a fashion store here in the UK. Your task as an intelligent customer "
            "assistant is to answer the customer's query "
            "You should also look at reviews to back up your answers."
            "The following are the products we pulled from the store that might "
            "match their query, use this to answer their question: {context}"
            f"User Query: {data_dict['question']}\n\n"
            "Text and / or tables:\n"
            f"{formatted_texts}"
        )
    }
    messages.append(text_message)
    
    return [HumanMessage(content=messages)]



In [115]:
if OLLAMA_IN_USE:
    from langchain_community.chat_models.ollama import ChatOllama
    model = ChatOllama(model=VISION_MODEL_NAME, temperature=0, 
                    num_ctx=4096, ## Max Tokens: 4096 for Ollama local models please
            )
else:
    model = ChatOpenAI(temperature=0, model="gpt-4o", max_tokens=1024)

# RAG pipeline
chain = (
    { "context": retriever | RunnableLambda(split_image_text_types),
        "question": RunnablePassthrough()
     }
    | RunnableLambda(prompt_func)
    | model
    | StrOutputParser()
)

In [116]:
chain.invoke("What is the best material to wear during winter?")

TypeError: sequence item 0: expected str instance, Document found

In [None]:

from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

def get_conversation_chain():
    """Returns a chain that asks the model to generate an answer based on the given context and user summarised conversation.

    Args:
        vector_store (Chroma): The vectorstore that contains the documents

    Returns:
        Runnable: A LargeChain Runnable
    """

    # llm = ChatOpenAI(temperature=0, model="gpt-4o", max_tokens=500)
    llm = ChatOllama(temperature=0, model=VISION_MODEL_NAME)

    prompt = ChatPromptTemplate.from_messages(
        [
            ("system", CUSTOMER_SUPPORT_EXPERT_INTRODUCTORY_INSTRUCTION),
            MessagesPlaceholder(variable_name="chat_history"),
            ("user", "{input}"),
        ]
    )

    stuff_documents_chain = create_stuff_documents_chain(llm=llm, prompt=prompt)

    return create_retrieval_chain(
        retriever=get_retrieval_chain(vector_store=vector_store),
        combine_docs_chain=stuff_documents_chain,
    )



def get_retrieval_chain(retriever):
    """Retrieves the document that matches the query of the user. The function uses the `create_history_aware_retriever`
    function from LangChain to collate the conversations we have been having with this user, summarises it and use the summary
    to find relevant documents in the vector store.

    Args:
        vector_store (Chroma): The vector that contains the documents.

    Returns:
        Runnable: A LargeChain Runnable
    """

    llm = ChatOllama(temperature=0, model="llama3")

    # Using the normal retriever technique, as our use case is simple.
    # Using other retriever could cause slight delay in responding to the
    # user as they will need to use LLM to perform various tasks.
    # Our use-case is simple and using this retriever works most (if not all) the time.
    # retriever

    prompt = ChatPromptTemplate.from_messages(
        [
            MessagesPlaceholder(variable_name="chat_history"),
            SystemMessage(content="As an expert in a Fashion Store, given the above conversation, generate a search query to look up in order to get information relevant to the  conversation"),
            ("user", "{input}"),
        ]
    )

    return create_history_aware_retriever(llm=llm, retriever=retriever, prompt=prompt)

In [None]:
conversation_chain = get_conversation_chain()
assistant_response = conversation_chain.invoke(
    {"input": "I need a sneakers", "chat_history": [] }
)