# Multimodal RAG with the neo4j-genai package

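The notebook reads its Neo4j connection settings from environment variables (a `.env` file is loaded with `load_dotenv()`). The values below are the defaults used when a variable is not set:

```shell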
NEO4J_URI=bolt://localhost:7687
NEO4J_USER=neo4j
NEO4J_PASSWORD=password
NEO4J_DATABASE=neo4j
```

In [None]:
import os
from dotenv import load_dotenv

import neo4j
import ollama

from PIL import Image
from sentence_transformers import SentenceTransformer

from neo4j_genai.retrievers import VectorRetriever, VectorCypherRetriever
from neo4j_genai.types import RetrieverResultItem
from neo4j_genai.embeddings import SentenceTransformerEmbeddings
from neo4j_genai.llm import LLMInterface
from neo4j_genai.llm.types import LLMResponse
from neo4j_genai.generation import GraphRAG

from IPython.display import Image as IPythonImage

In [None]:
# Load a .env file (if one is found) before the os.getenv lookups below.
load_dotenv()
# Neo4j connection settings; each one falls back to a local default when the
# corresponding environment variable is not set.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://localhost:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "password")
NEO4J_DATABASE = os.environ.get("NEO4J_DATABASE", "neo4j")

# Name of the vector index queried by every retriever in this notebook.
POSTER_INDEX_NAME = "moviePostersEmbedding"
# Model identifier handed to SentenceTransformer / SentenceTransformerEmbeddings.
IMAGE_EMBEDDING_MODEL = "clip-ViT-B-32"

In [None]:
# One driver for the whole notebook; it is passed to every retriever below.
driver = neo4j.GraphDatabase.driver(
    NEO4J_URI,
    auth=(NEO4J_USER, NEO4J_PASSWORD),
    database=NEO4J_DATABASE,
)

## Search for similar images

In [None]:
# Load the embedding model named by IMAGE_EMBEDDING_MODEL; it is used below
# to encode the query image.
model = SentenceTransformer(IMAGE_EMBEDDING_MODEL)

In [None]:
image_path = "./images/Notre-Dame_de_Paris_2013-07-24.jpg"
# Open the image through a context manager so its file handle is released as
# soon as the embedding has been computed, instead of staying open for the
# lifetime of the kernel.
with Image.open(image_path) as query_image:
    vector = model.encode(query_image).tolist()
# Show the first five components of the embedding as the cell output.
vector[:5]

In [None]:
# Retriever over the poster index. No embedder is configured here; the search
# below passes an already-computed vector.
retriever = VectorRetriever(
    driver,
    index_name=POSTER_INDEX_NAME,
)

In [None]:
# Query the index with the image embedding (top_k=4) and print each returned item.
result = retriever.search(query_vector=vector, top_k=4)
for r in result.items:
    print(r)

In [None]:
# Same index, now with a retrieval query that returns the title, plot,
# poster URL and score for each match.
retriever = VectorCypherRetriever(  # NEW
    driver,
    index_name=POSTER_INDEX_NAME,
    retrieval_query="RETURN node.title as title, node.plot as plot, node.poster as posterUrl, score",  # NEW
)

In [None]:
# Re-run the same vector search with the new retriever and print each item.
result = retriever.search(query_vector=vector, top_k=4)
for r in result.items:
    print(r)

In [None]:
def format_record_function(record: neo4j.Record) -> RetrieverResultItem:
    """Turn one retrieval-query record into a ``RetrieverResultItem``.

    Args:
        record: A record exposing the ``title``, ``plot``, ``posterUrl`` and
            ``score`` fields returned by the retrieval query.

    Returns:
        An item whose ``content`` is a sentence combining the movie title and
        plot, and whose ``metadata`` holds the title, plot, poster URL (under
        the ``poster`` key) and score.
    """
    title = record.get("title")
    plot = record.get("plot")
    metadata = {
        "title": title,
        "plot": plot,
        "poster": record.get("posterUrl"),
        "score": record.get("score"),
    }
    return RetrieverResultItem(
        content=f"Movie title: {title}, movie plot: {plot}",
        metadata=metadata,
    )


# Same retriever as before, with records converted by format_record_function.
retriever = VectorCypherRetriever(
    driver,
    index_name=POSTER_INDEX_NAME,
    retrieval_query="RETURN node.title as title, node.plot as plot, node.poster as posterUrl, score",
    format_record_function=format_record_function,  # NEW
)

# Search with the image embedding, then print each item's text and score and
# render its poster from the URL stored in the metadata.
result = retriever.search(query_vector=vector, top_k=4)
for r in result.items:
    print(r.content, r.metadata["score"])
    display(IPythonImage(url=r.metadata["poster"]))

## Search images by their content

In [None]:
# Text question used for the text-based search below.
query_text = "Find a movie taking place in Paris and explain the plot."
# Number of results requested from the retriever in the following cells.
top_k = 3

In [None]:
# An embedder built from the same model as the image embeddings is supplied,
# so the search below can be given query text instead of a vector.
retriever = VectorCypherRetriever(
    driver,
    index_name=POSTER_INDEX_NAME,
    retrieval_query="RETURN node.title as title, node.plot as plot, node.poster as posterUrl, score",
    embedder=SentenceTransformerEmbeddings(IMAGE_EMBEDDING_MODEL),  # NEW
    format_record_function=format_record_function,
)


In [None]:
# Search the poster index from the text question rather than from an image vector.
result = retriever.search(query_text=query_text, top_k=top_k)

In [None]:
# Print each match's text and score, then render its poster from the stored URL.
for r in result.items:
    print(r.content, r.metadata.get("score"))
    display(IPythonImage(url=r.metadata["poster"]))

## RAG by searching on images

In [None]:
# Build a RAG pipeline on top of the current retriever, using a LangChain
# chat model as the LLM, and ask the Paris question.
from langchain_community.chat_models import ChatOllama
llm = ChatOllama(model="llama3:8b")
rag = GraphRAG(retriever=retriever, llm=llm)
rag_result = rag.search(
    "Find a movie taking place in Paris and explain the plot.", 
    retriever_config={"top_k": top_k},
)
print(rag_result.answer)

In [None]:
class OllamaLLM(LLMInterface):
    """``LLMInterface`` implementation that answers prompts via ``ollama.chat``."""

    def invoke(self, input: str) -> LLMResponse:
        """Send ``input`` as a single user message and wrap the reply.

        Args:
            input: The prompt text, sent as the content of one ``user`` message.

        Returns:
            An ``LLMResponse`` whose content is the text found at
            ``response["message"]["content"]`` in the ``ollama.chat`` reply.
        """
        user_message = {"role": "user", "content": input}
        reply = ollama.chat(model=self.model_name, messages=[user_message])
        return LLMResponse(content=reply["message"]["content"])

In [None]:
# Rebuild the RAG pipeline, this time with the OllamaLLM class defined above.
rag = GraphRAG(
    retriever=retriever,
    llm=OllamaLLM('llama3:8b')
)

In [None]:
# Ask the question through the RAG pipeline. return_context=True is passed so
# the retrieved items can be inspected via rag_result.retriever_result in the
# next cell. The commented-out line is an alternative question to try.
rag_result = rag.search(
    # "Find a movie with astronauts and explain the plot.", 
    "Find a movie taking place in Paris and explain the plot.", 
    retriever_config={"top_k": top_k},
    return_context=True,
)
print(rag_result.answer)

In [None]:
# List the text of every retrieved item returned alongside the answer.
[r.content for r in rag_result.retriever_result.items]