In [6]:
from loader.daily_loader import load_daily_paper
from chater.chat import mapreduce
import os
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from langchain.vectorstores import FAISS
from langchain.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_nvidia_ai_endpoints import ChatNVIDIA
import faiss
import json

from vec_generate.arxiv_generate import arxiv_generate

In [14]:
#初始化部分
os.environ["NVIDIA_API_KEY"] = "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"
model = ChatNVIDIA(model="ai-llama3-8b")#.bind(max_tokens=4096)
embedder = NVIDIAEmbeddings(model="ai-embed-qa-4", truncate="END")
cache = {}
try:
    cache = json.load(open("cache.json"))  # 读取用户有过哪些向量库   里面存的应该是tuple (vectorstore_path, initial_msg)
except:
    cache = {}
    json.dump(cache, open("cache.json", "w"))

In [15]:
len(embedder.embed_query("test"))

1024

这里先不给用户造自己的库的功能 只有daily chat和 awesome chat 两个功能  本地有个csv表格记录已经做了哪些库

In [16]:


def daily_chat(date, cat, model, embedder, cache={}, verbose=False):

    '''
    接受日期 类别 模型和cache, 返回一个向量库路径和一个开场词
    如果用户指定verbose为True, 则当天的arxiv表格会保存在该库本地
    '''

    if date+cat in cache:
        return cache[date+cat]
    
    else:
        daily = load_daily_paper(date)

        if(len(daily) == 0):
            print("No papers found for this date and category, please retry")
            return None, None

        summary, detail = mapreduce(model, daily)
        if verbose:
            daily.to_csv(f"{date}_{cat}.csv")  # 保存当天的arxiv表格
        
        path_name = 'local_embed/'+date+cat
        arxiv_generate(detail['id'], embedder, path_name)
        
        cache[date+cat] = (path_name, summary)
        json.dump(cache, open("cache.json", "w"))

        return path_name, summary
        
        

In [17]:
from loader.awesome_loader import load_from_awesome
def awesome_chat(md_path, embedder, cache={}):
    '''
    接受md路径和cache, 返回一个向量库路径和一个开场词
    '''
    if md_path in cache:
        return cache[md_path]
    else:
        with open(md_path, 'r') as f:
            text = f.read()
            
        path_name = 'local_embed/'+md_path.split('.')[0]
        
        paper_list = load_from_awesome(text)

        msg = (
    "你好，我是awesome-parser， 一个帮助用户解析awesome-list的工具。"
    f"我从您提供的文档{md_path}中解析到了{len(paper_list)}篇paper，我可以帮助您吗")

        arxiv_generate(paper_list, embedder, path_name)
        cache[md_path] = (path_name, msg)
        json.dump(cache, open("cache.json", "w"))

        return path_name, msg
    

In [22]:
path, init_msg = daily_chat("20240707", "cs.CR", model, embedder, cache, verbose=True)

In [18]:
path, init_msg = awesome_chat("prompt_injection.md", embedder, cache)

Loading Documents
Cleaning Documents
Chunking Documents
Creating Vectorstores


In [20]:
print(path, init_msg)

local_embed/prompt_injection 你好，我是awesome-parser， 一个帮助用户解析awesome-list的工具。我从您提供的文档prompt_injection.md中解析到了16篇paper，我可以帮助您吗


In [21]:
vecstores = [FAISS.load_local(folder_path=path, embeddings=embedder,allow_dangerous_deserialization=True)]

In [22]:
from vec_generate.arxiv_generate import aggregate_vstores

docstore = aggregate_vstores(vecstores, embedder)
print(f"Constructed aggregate docstore with {len(docstore.docstore._dict)} chunks")

Constructed aggregate docstore with 3470 chunks


In [10]:
from vec_generate.arxiv_generate import default_FAISS
convstore = default_FAISS(embedder)

In [11]:
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableLambda
from functools import partial
from langchain.document_transformers import LongContextReorder
from operator import itemgetter
from langchain_core.runnables.passthrough import RunnableAssign

    
def save_memory_and_get_output(d, vstore):
    """Accepts 'input'/'output' dictionary and saves to convstore"""
    vstore.add_texts([
        f"User previously responded with {d.get('input')}",
        f"Agent previously responded with {d.get('output')}"
    ])
    return d.get('output')

initial_msg = init_msg

chat_prompt = ChatPromptTemplate.from_messages([("system",
    "You are a document chatbot. Help the user as they ask questions about documents."
    " User messaged just asked: {input}\n\n"
    " From this, we have retrieved the following potentially-useful info: "
    " Conversation History Retrieval:\n{history}\n\n"
    " Document Retrieval:\n{context}\n\n"
    " (Answer only from retrieval. Only cite sources that are used. Make your response conversational.Reply must more than 100 words)"
), ('user', '{input}')])

## Utility Runnables/Methods
def RPrint(preface=""):
    """Simple passthrough "prints, then returns" chain"""
    def print_and_return(x, preface):
        print(f"{preface}{x}")
        return x
    return RunnableLambda(partial(print_and_return, preface=preface))

def docs2str(docs, title="Document"):
    """Useful utility for making chunks into context string. Optional, but useful"""
    out_str = ""
    for doc in docs:
        doc_name = getattr(doc, 'metadata', {}).get('Title', title)
        if doc_name:
            out_str += f"[Quote from {doc_name}] "
        out_str += getattr(doc, 'page_content', str(doc)) + "\n"
    return out_str

## 将较长的文档重新排序到输出文本的中心， RunnableLambda在链中运行无参自定义函数 ，长上下文重排序（LongContextReorder）
long_reorder = RunnableLambda(LongContextReorder().transform_documents)

retrieval_chain = (
    {'input' : (lambda x: x)}
    | RunnableAssign({'history' : itemgetter('input') | convstore.as_retriever() | long_reorder | docs2str})
    | RunnableAssign({'context' : itemgetter('input') | docstore.as_retriever()  | long_reorder | docs2str})
    | RPrint()
)
stream_chain = chat_prompt | model | StrOutputParser()

def chat_gen(message, history=[], return_buffer=True):
    buffer = ""
    ##首先根据输入的消息进行检索
    retrieval = retrieval_chain.invoke(message)
    line_buffer = ""

    ## 然后流式传输stream_chain的结果
    for token in stream_chain.stream(retrieval):
        buffer += token
        ## 优化信息打印的格式
        if not return_buffer:
            line_buffer += token
            if "\n" in line_buffer:
                line_buffer = ""
            if ((len(line_buffer)>84 and token and token[0] == " ") or len(line_buffer)>100):
                line_buffer = ""
                yield "\n"
                token = "  " + token.lstrip()
        yield buffer if return_buffer else token

    ##最后将聊天内容保存到对话内存缓冲区中
    save_memory_and_get_output({'input':  message, 'output': buffer}, convstore)

In [12]:
import gradio as gr
chatbot = gr.Chatbot(value = [[None, initial_msg]])
demo = gr.ChatInterface(chat_gen, chatbot=chatbot).queue()

try:
    demo.launch(debug=True, share=False, show_api=False, server_port=5000, server_name="0.0.0.0")
    demo.close()
except Exception as e:
    demo.close()
    print(e)
    raise e

Running on local URL:  http://0.0.0.0:5000

To create a public link, set `share=True` in `launch()`.


{'input': '什么是prompt注入攻击', 'history': '', 'context': '[Quote from Prompt Injection Attacks in Defended Systems] competition aims to foster innovation and collaboration in the\nfield of LLM security.\nA. Defense Prompt\nThe defense prompt is appended to the general system\nprompt and the secret-carrying sentence in the system prompt\n[Quote from Prompt Injection Attacks in Defended Systems] to explain when a question does not make sense or is not\nfactually coherent, and to avoid sharing false information when\nthe answer is unknown. The defense prompt is guaranteed to be\nthe first message in any chat with this defense and is restricted\n[Quote from Prompt Injection Attacks in Defended Systems] use the fields Age, Value, Name, and already ask to fill the\nValue field with "your value."\nThe following are several of the most successful prompts in\nEnglish that were used:\n• An attack using a json object\n[Quote from Prompt Injection Attacks in Defended Systems] or the first message of t