In [16]:
%%capture
%pip install -q langchain langchain-community langchain-nvidia-ai-endpoints gradio rich
%pip install -q arxiv pymupdf faiss-cpu

In [32]:
from google.colab import userdata

import json

from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.document_loaders import ArxivLoader
from langchain.document_transformers import LongContextReorder
from langchain_core.runnables import RunnableLambda
from langchain_core.runnables.passthrough import RunnableAssign

from faiss import IndexFlatL2
from langchain_community.docstore.in_memory import InMemoryDocstore

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser

import gradio as gr

from functools import partial
from operator import itemgetter

from rich.console import Console
from rich.style import Style
from rich.theme import Theme

console = Console()
base_style = Style(color="#76B900", bold=True)
pprint = partial(console.print, style=base_style)

In [11]:
# NVIDIAEmbeddings.get_available_models()
embedder = NVIDIAEmbeddings(
    model="nvidia/nv-embed-v1",
    api_key=userdata.get('NV-EMD-KEY'),
    truncate="END")

In [14]:
# ChatNVIDIA.get_available_models()
instruct_llm = ChatNVIDIA(
    model="mistralai/mixtral-8x22b-instruct-v0.1",
    api_key=userdata.get('MX-INS-KEY')
    )

## Loading papers

In [None]:
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=1000,
    chunk_overlap=100,
    separators=["\n\n", "\n", ".", ";", ",", " "],
)

print("Loading Documents...")
docs = []
with open('papers.txt', 'r') as file:
    for line in file:
        docs.append(ArxivLoader(query=line).load())

# Cut off references
for doc in docs:
    content = json.dumps(doc[0].page_content)
    if "References" in content:
        doc[0].page_content = content[:content.index("References")]
print("...Done!")

# Chunk the documents and filter out stubs
print("Chunking Documents...")
docs_chunks = [text_splitter.split_documents(doc) for doc in docs]
docs_chunks = [[c for c in dchunks if len(c.page_content) > 200] for dchunks in docs_chunks]
print("...Done!")

print("Creating Catalog...")
# Add a catalog chunk
doc_string = "Available Documents:"
doc_metadata = []
for chunks in docs_chunks:
    metadata = getattr(chunks[0], 'metadata', {})
    doc_string += "\n - " + metadata.get('Title')
    doc_metadata += [str(metadata)]

extra_chunks = [doc_string] + doc_metadata
print("...Done!")

pprint(doc_string, '\n')

Loading Documents...
...Done!
Chunking Documents...
...Done!
Creating Catalog...
...Done!


## Construct Document Vector Store

In [29]:
%%time
print("Constructing Vector Stores...")
vecstores = [FAISS.from_texts(extra_chunks, embedder)]
vecstores += [FAISS.from_documents(doc_chunks, embedder) for doc_chunks in docs_chunks]
print("...Done!")

Constructing Vector Stores...
...Done!
CPU times: user 1.56 s, sys: 172 ms, total: 1.73 s
Wall time: 49 s


In [31]:
embed_dims = len(embedder.embed_query("test"))
def default_FAISS():
    """Make an empty FAISS vecstore"""
    return FAISS(
        embedding_function=embedder,
        index=IndexFlatL2(embed_dims),
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
        normalize_L2=False,
    )

def aggregate_vstores(vectorstores):
    """Initialize an empty FAISS Index and merge others into it."""
    agg_vstore = default_FAISS()
    for vstore in vecstores:
        agg_vstore.merge_from(vstore)
    return agg_vstore

print("Merging Vectore Stores...")
docstore = aggregate_vstores(vecstores)
print("...Done")
print(f"Constructed aggregate docstore with {len(docstore.docstore._dict)} chunks.")

Merging Vectore Stores...
...Done
Constructed aggregate docstore with 629 chunks.


## RAG Chain

In [33]:
# utilities

def RPrint(preface=""):
    """Simple passthrough "prints, then returns" chain"""
    def print_and_return(x, preface):
        if preface: print(preface, end="")
        pprint(x)
        return x
    return RunnableLambda(partial(print_and_return, preface=preface))

def docs2str(docs, title="Document"):
    """Useful utility for making chunks into context string."""
    out_str = ""
    for doc in docs:
        doc_name = getattr(doc, 'metadata', {}).get('Title', title)
        if doc_name:
            out_str += f"[Quote from {doc_name}] "
        out_str += getattr(doc, 'page_content', str(doc)) + "\n"
    return out_str

# Reorder longer documents to center of output text
long_reorder = RunnableLambda(LongContextReorder().transform_documents)


def save_memory_and_get_output(d, vstore):
    """Accepts {input,output} dict and saves to vstore"""
    vstore.add_texts([
        f"User previously responded with {d.get('input')}",
        f"Agent previously responded with {d.get('output')}"
    ])
    return d.get('output')

In [37]:
# conversation vecstore
convstore = default_FAISS()

chat_prompt = ChatPromptTemplate.from_messages([("system",
    "You are a document chatbot. Help the user as they ask questions about documents."
    " User messaged just asked: {input}\n\n"
    " From this, we have retrieved the following potentially-useful info: "
    " Conversation History Retrieval:\n{history}\n\n"
    " Document Retrieval:\n{context}\n\n"
    " (Answer only from retrieval. Only cite sources that are used. Make your response conversational.)"
), ('user', '{input}')])

## -> {input, history, context}
stream_chain = chat_prompt | instruct_llm | StrOutputParser()

retrieval_chain = (
    {'input': (lambda x: x)}
    | RunnableAssign({'history' : itemgetter("input") | convstore.as_retriever() | long_reorder | docs2str})
    | RunnableAssign({'context' : itemgetter("input") | docstore.as_retriever() | long_reorder | docs2str})
)

def chat_gen(message, history=[], return_buffer=True):
    buffer = ""

    retrieval = retrieval_chain.invoke(message)
    line_buffer = ""

    for token in stream_chain.stream(retrieval):
        buffer += token
        yield buffer if return_buffer else token

    save_memory_and_get_output({'input': message, 'output': buffer}, convstore)

In [38]:
test_question = "Tell me about sDREAMER"

for response in chat_gen(test_question, return_buffer=False):
    print(response, end='')

The sDREAMER is a model proposed in a study which uses a self-distilled Mixture-of-Modality-Experts (MoME) Transformer for automatic sleep staging. It's a multi-modal learning framework that handles both single-channel and multi-channel inputs. The sDREAMER model emphasizes cross-modality interaction and per-channel performance, which results in high-quality staging results on the sleep staging task.

The architecture of the sDREAMER model operates at two levels: epoch and sequence. At the epoch level, it captures epoch-level contexts using an epoch-level MoME module. At the sequence level, a sequence-level MoME transformer captures sequence-level contexts. Even though both levels share the same network architecture, the information flow differs between them during training and inference stages.

The sDREAMER model outperforms not only machine-learning-based methods but also deep-learning-based methods on both epoch and sequence levels. The study also tested sDREAMER's ability to gener

In [39]:
initial_msg = (
    "Hello! I am a document chat agent here to help the user!"
    f" I have access to the following documents: {doc_string}\n\nHow can I help you?"
)

chatbot = gr.Chatbot(value = [[None, initial_msg]])
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()

try:
    demo.launch(debug=True, share=True, show_api=False)
    demo.close()
except Exception as e:
    demo.close()
    print(e)
    raise e



Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://7203dd3b0e41862e47.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://7203dd3b0e41862e47.gradio.live
Closing server running on port: 7860
