In [None]:
import uuid
from langchain.vectorstores import Chroma
from langchain.storage import InMemoryStore
from langchain.schema.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever

# Metadata field under which each indexed summary carries the id of its original.
id_key = "doc_id"

# Key/value store that holds the original (parent) content.
store = InMemoryStore()

# HuggingFace sentence-transformers model used to embed whatever gets indexed.
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)

# Chroma collection that the child chunks / summaries are embedded into.
vectorstore = Chroma(
    embedding_function=embedding_model,
    collection_name="multi_modal_rag",
)

# Tie the vector index and the docstore together; both are still empty here.
retriever = MultiVectorRetriever(
    id_key=id_key,
    docstore=store,
    vectorstore=vectorstore,
)


In [None]:
# Index every text under its summary: the summary is embedded into the
# vectorstore, the original text is written to the docstore, and both share a
# freshly generated id stored under `id_key`.
# Note: each run generates new ids, so re-running this cell adds the same
# summaries to the vectorstore again.
if len(summaries) != len(texts):
    # Ids are created per text. Without this check, extra summaries fail with
    # an opaque IndexError and missing summaries leave texts in the docstore
    # that no indexed summary points to.
    raise ValueError(
        f"Expected one summary per text, got {len(summaries)} summaries "
        f"for {len(texts)} texts"
    )

doc_ids = [str(uuid.uuid4()) for _ in texts]
summary_docs = [
    Document(page_content=summary, metadata={id_key: doc_id})
    for doc_id, summary in zip(doc_ids, summaries)
]
if summary_docs:  # skip the store writes when there is nothing to add
    retriever.vectorstore.add_documents(summary_docs)
    retriever.docstore.mset(list(zip(doc_ids, texts)))

# --- Disabled: tables and images -------------------------------------------
# Kept as `#` comments rather than a triple-quoted string: a bare string as the
# last expression of a cell is echoed as the cell's output.
# NOTE(review): the disabled code originally paired `tables` and `images` with
# the text `summaries` list. `table_summaries` / `image_summaries` below are
# placeholders for the lists that hold those summaries; define them (one entry
# per table / image) before enabling.
#
# # Add tables
# table_ids = [str(uuid.uuid4()) for _ in tables]
# summary_tables = [
#     Document(page_content=summary, metadata={id_key: table_id})
#     for table_id, summary in zip(table_ids, table_summaries)
# ]
# retriever.vectorstore.add_documents(summary_tables)
# retriever.docstore.mset(list(zip(table_ids, tables)))
#
# # Add image summaries
# img_ids = [str(uuid.uuid4()) for _ in images]
# summary_img = [
#     Document(page_content=summary, metadata={id_key: img_id})
#     for img_id, summary in zip(img_ids, image_summaries)
# ]
# retriever.vectorstore.add_documents(summary_img)
# retriever.docstore.mset(list(zip(img_ids, images)))

In [None]:
# Run a sample query through the retriever and list whatever comes back.
question = "What is an intrusion detection system (IDS)?"
docs = retriever.invoke(question)

print("\n📄 Retrieved Documents:")
for hit in docs:
    # A result may expose its text directly, via a nested `.document`, or not
    # at all; in the last case fall back to its string form.
    if hasattr(hit, "page_content"):
        shown = hit.page_content
    elif hasattr(getattr(hit, "document", None), "page_content"):
        shown = hit.document.page_content
    else:
        shown = str(hit)
    print("-", shown)
