<a href="https://colab.research.google.com/github/russellemergentai/MistralInstruct/blob/main/Langchain_Mistral_Agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Install required packages
!pip install langchain
!pip install langchain-community
!pip install langchain-chroma
!pip install accelerate
!pip install bitsandbytes
!pip install wikipedia
!pip install transformers #<= clash on numpy, kernel restart
!pip install langchain-huggingface #<= clash on numpy, kernel restart

# Suppress warnings
import warnings
warnings.filterwarnings('ignore')

# Import transformers models and utilities
from transformers import pipeline
from transformers.models.mistral.modeling_mistral import MistralForCausalLM
from transformers.models.llama.tokenization_llama_fast import LlamaTokenizerFast
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Import LangChain modules and utilities
from langchain.tools import WikipediaQueryRun, BaseTool
from langchain.agents import Tool
from langchain_community.utilities import WikipediaAPIWrapper
from langchain.llms.base import LLM
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain.agents import create_json_chat_agent, AgentExecutor
from langchain.memory import ConversationBufferMemory

from langchain.chains import RetrievalQA
from langchain_huggingface import HuggingFacePipeline
from langchain_chroma import Chroma
from langchain.document_loaders import TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.storage import InMemoryByteStore
from pathlib import Path

# Import core libraries and dependencies
import numexpr as ne
import os, uuid, torch
from typing import Optional, List, Mapping, Any

# Mount Google Drive so files under /content/drive are readable (Colab only).
from google.colab import drive
drive.mount('/content/drive')

# NOTE(review): `login` and `userdata` are imported but never called in the
# visible code - presumably intended for Hugging Face authentication; confirm.
from huggingface_hub import login
from google.colab import userdata
# Load the model and tokenizer. The model is requested in 4-bit form
# (load_in_4bit=True) and placed automatically across devices (device_map="auto").
model_name = "mistralai/Mistral-7B-Instruct-v0.3"
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, quantization_config=quantization_config, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Wrap the LLM
class CustomLLMMistral(LLM):
    """LangChain ``LLM`` wrapper around a locally loaded Mistral instruct model.

    Each call sends the prompt as a single user chat message, samples a
    completion, and post-processes the text so the agent's output parser
    receives a closed Markdown snippet.
    """

    # Causal LM used for generation.
    model: MistralForCausalLM
    # Tokenizer that supplies the chat template and decodes generated ids.
    tokenizer: LlamaTokenizerFast

    @property
    def _llm_type(self) -> str:
        """Return the type label reported for this LLM."""
        return "custom"

    def _call(self, prompt: str, stop: Optional[List[str]] = None,
              run_manager: Optional[CallbackManagerForLLMRun] = None,
              **kwargs: Any) -> str:
        """Generate a completion for ``prompt``.

        Args:
            prompt: Text sent to the model as one user message.
            stop: Optional stop words; the output is truncated at the first
                occurrence of each, in order.
            run_manager: Accepted for interface compatibility; unused.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The generated text with the end-of-sequence marker removed,
            truncated at the stop words, and padded with backticks until it
            ends with a closing code fence.
        """
        messages = [{"role": "user", "content": prompt}]

        encodeds = self.tokenizer.apply_chat_template(messages, return_tensors="pt")
        model_inputs = encodeds.to(self.model.device)

        generated_ids = self.model.generate(
            model_inputs,
            max_new_tokens=512,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            top_k=4,
            temperature=0.7,
        )

        # generate() returns the prompt tokens followed by the completion.
        # Drop the prompt by position instead of splitting the decoded text on
        # "[/INST]": a prompt that itself contains that marker (for example
        # inside a tool observation) would otherwise yield the wrong segment.
        prompt_length = model_inputs.shape[-1]
        completion_ids = generated_ids[:, prompt_length:]
        decoded = self.tokenizer.batch_decode(completion_ids)

        output = decoded[0].replace("</s>", "").strip()

        for word in stop or []:
            output = output.split(word)[0].strip()

        # Mistral 7B sometimes fails to properly close the Markdown snippets.
        # If they are not correctly closed, LangChain will struggle to parse
        # the output, so pad with backticks until the text ends with a fence.
        while not output.endswith("```"):
            output += "`"

        return output

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the parameters that identify this LLM instance."""
        return {"model": self.model}


# Module-level LLM instance wrapping the model and tokenizer loaded above.
llm = CustomLLMMistral(model=model, tokenizer=tokenizer)

### Tools

In [None]:
# Wikipedia lookup limited to the top result and 2500 characters of content.
wikipedia = WikipediaQueryRun(
    api_wrapper=WikipediaAPIWrapper(
        top_k_results=1,
        doc_content_chars_max=2500,
    )
)

# Expose the lookup to the agent under the name "wikipedia".
wikipedia_tool = Tool(
    func=wikipedia.run,
    name="wikipedia",
    description=(
        "Never search for more than one concept at a single step. "
        "If you need to compare two concepts, search for each one individually. "
        "Syntax: string with a simple concept"
    ),
)

class Calculator(BaseTool):
    """Agent tool that evaluates a math expression written in numexpr syntax."""

    name: str = "calculator"
    description: str = "Use this tool for math operations. It requires numexpr syntax. Use it always you need to solve any math operation. Be sure syntax is correct."

    def _run(self, expression: str):
        """Evaluate ``expression`` with numexpr.

        Args:
            expression: Expression string in numexpr syntax.

        Returns:
            The evaluated value as a Python scalar, or an explanatory message
            when evaluation fails, so the agent can retry with a different
            expression instead of aborting the run.
        """
        # NOTE(review): the expression is produced by the model; confirm the
        # installed numexpr version sanitizes its input before evaluating it.
        try:
            return ne.evaluate(expression).item()
        except Exception:
            # Deliberately broad: any failure is reported back to the agent.
            return "This is not a numexpr valid syntax. Try a different syntax."

    async def _arun(self, expression: str):
        """Async variant; not supported.

        The signature mirrors ``_run`` so an async call reaches the
        ``NotImplementedError`` below rather than failing on an unexpected
        keyword argument.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("This tool does not support async")

# Calculator tool instance made available to the agent.
calculator_tool = Calculator()


def create_multivector_directory_retriever(directory_path):
    """Build a parent/child multi-vector retriever over the files in a directory.

    Every file found recursively under ``directory_path`` is loaded as UTF-8
    text and split into coarse (parent) chunks. Each parent is split again
    into smaller (child) chunks tagged with the parent's id. Child chunks are
    added to the Chroma vector store; parent chunks are stored in the
    retriever's docstore under the same ids.

    Args:
        directory_path: Directory to scan recursively for files.

    Returns:
        The populated ``MultiVectorRetriever``.
    """
    parent_splitter = RecursiveCharacterTextSplitter(chunk_size=500)  # coarse (parent) chunks
    child_splitter = RecursiveCharacterTextSplitter(chunk_size=250)  # granular (child) chunks

    model_path = "intfloat/e5-large-unsupervised"

    embeddings = HuggingFaceEmbeddings(
        model_name=model_path,
        model_kwargs={'device': 'cuda'},
        encode_kwargs={'normalize_embeddings': False}
    )

    child_chunks_collection = Chroma(
        collection_name="uk_child_chunks",
        embedding_function=embeddings,
    )

    # Reset before indexing; presumably this clears chunks left in the
    # collection by an earlier call - confirm against the Chroma wrapper.
    child_chunks_collection.reset_collection()

    doc_byte_store = InMemoryByteStore()
    doc_key = "doc_id"

    multi_vector_retriever = MultiVectorRetriever(
        vectorstore=child_chunks_collection,
        byte_store=doc_byte_store,
        # Keep the retriever's lookup key tied to the metadata key set below.
        id_key=doc_key,
    )

    all_documents = []
    for file_path in Path(directory_path).rglob('*'):
        if file_path.is_file():
            loader = TextLoader(str(file_path), encoding='UTF-8')
            all_documents.extend(loader.load())

    coarse_chunks = parent_splitter.split_documents(all_documents)
    coarse_chunks_ids = [str(uuid.uuid4()) for _ in coarse_chunks]

    all_granular_chunks = []
    for coarse_chunk_id, coarse_chunk in zip(coarse_chunks_ids, coarse_chunks):
        granular_chunks = child_splitter.split_documents([coarse_chunk])

        for granular_chunk in granular_chunks:
            granular_chunk.metadata[doc_key] = coarse_chunk_id

        # Extend once per parent. Extending inside the loop above would add
        # every child chunk once per sibling and index duplicates.
        all_granular_chunks.extend(granular_chunks)

    multi_vector_retriever.vectorstore.add_documents(all_granular_chunks)
    multi_vector_retriever.docstore.mset(list(zip(coarse_chunks_ids, coarse_chunks)))

    return multi_vector_retriever


  # retrieve from data directory
def retrieval_multivector_query_data(expression: str):
    """Answer ``expression`` using the files under the Drive target directory.

    A retriever is built over ``/content/drive/MyDrive/Target``, a
    text-generation pipeline is wrapped around the already-loaded model and
    tokenizer, and the query is run through a ``RetrievalQA`` chain. No custom
    prompt is passed to the chain.

    Args:
        expression: The question to answer.

    Returns:
        The chain's answer (source documents are not returned).
    """
    import gc

    path = "/content/drive/MyDrive/Target"

    # NOTE: the retriever is rebuilt on every call, so the whole directory is
    # re-read and re-embedded for each query.
    retriever = create_multivector_directory_retriever(path)

    # Only max_new_tokens bounds the output length. It takes precedence over
    # max_length when both are given, so max_length is not passed at all.
    pipelineQuery = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
    )

    llmPipelineQuery = HuggingFacePipeline(pipeline=pipelineQuery, model_kwargs={"temperature": 0.1})
    qa = RetrievalQA.from_chain_type(llm=llmPipelineQuery, retriever=retriever, return_source_documents=False)
    result = qa.run({"query": expression})

    # Drop the per-query objects before collecting so they can be reclaimed.
    del pipelineQuery
    del llmPipelineQuery
    del qa
    del retriever
    gc.collect()

    return result


class RAGQuery(BaseTool):
    """Agent tool that answers a query from the user's files via retrieval QA."""

    name: str = "rag"
    description: str = "Use this tool for retrieval augmented generation rag operations from my personal files. \
    It requires a query. \
    Use it to always when rag is requested or the subject is: Murex; Summit; STF."

    def _run(self, expression: str = ""):
        """Run the retrieval query.

        Args:
            expression: The question to answer; defaults to an empty string.

        Returns:
            The answer from ``retrieval_multivector_query_data``, or a message
            describing the exception, so a failure is reported to the agent
            instead of aborting the run.
        """
        try:
            return retrieval_multivector_query_data(expression)
        except Exception as e:
            # Deliberately broad: any failure is reported back to the agent.
            return f"An exception occurred: {e}"

    async def _arun(self, expression: str = ""):
        """Async variant; not supported.

        The signature mirrors ``_run`` so an async call reaches the
        ``NotImplementedError`` below rather than failing on an unexpected
        keyword argument.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("This tool does not support async")

# RAG tool instance made available to the agent.
rag_tool = RAGQuery()


# Full tool set handed to the agent and its executor.
tools = [wikipedia_tool, calculator_tool, rag_tool]


### Prompt

In [None]:
# System message: describes the JSON step format the model must emit.
# {tool_names} and {tools} are template placeholders.
system="""
You are designed to solve tasks. Each task requires multiple steps that are represented by a markdown code snippet of a json blob.
The json structure should contain the following keys:
thought -> your thoughts
action -> name of a tool
action_input -> parameters to send to the tool

These are the tools you can use: {tool_names}.

These are the tools descriptions:

{tools}

If you have enough information to answer the query use the tool "Final Answer". Its parameters is the solution.
If there is not enough information, keep trying.
"""

# Human message: shows an example snippet followed by the literal word STOP
# and carries the user's query in the {input} placeholder. The doubled braces
# are escapes so the example JSON keeps literal braces in the template.
human="""
Add the word "STOP" after each markdown snippet. Example:

```json
{{"thought": "<your thoughts>",
 "action": "<tool name or Final Answer to give a final answer>",
 "action_input": "<tool parameters or the final output"}}
```
STOP

This is my query="{input}". Write only the next step needed to solve it.
Your answer should be based in the previous tools executions, even if you think you know the answer.
Remember to add STOP after each snippet.

These were the previous steps given to solve this query and the information you already gathered:
"""

# Message order: system text, optional chat history, the human text, then the
# agent scratchpad.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", system),
        MessagesPlaceholder("chat_history", optional=True),
        ("human", human),
        MessagesPlaceholder("agent_scratchpad")
    ]
)

### Agents

In [None]:
# JSON chat agent: "STOP" is passed as the stop sequence and tool output is
# fed back using the bare "{observation}" template.
agent = create_json_chat_agent(
    llm=llm,
    tools=tools,
    prompt=prompt,
    stop_sequence=["STOP"],
    template_tool_response="{observation}",
)

# Conversation memory exposed to the prompt under the "chat_history" key.
memory = ConversationBufferMemory(return_messages=True, memory_key="chat_history")

# Omit `memory` here to make the executor stateless.
agent_executor = AgentExecutor(
    agent=agent,
    tools=tools,
    memory=memory,
    verbose=True,
    handle_parsing_errors=True,
)

def main():
    """Run an interactive loop that forwards each typed query to the agent.

    Entering ``x`` (any case, surrounding whitespace ignored) ends the loop,
    as does end-of-input or a keyboard interrupt at the prompt. Blank input is
    skipped. The query is passed to the agent with its original casing, so
    proper nouns and other case-sensitive text reach the tools intact.
    """
    while True:
        try:
            query = input("Enter query: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("Exiting.")
            break

        # Only the exit check is case-insensitive; the query itself is not
        # lowercased.
        if query.lower() == "x":
            print("Exiting.")
            break

        if not query:
            continue

        agent_executor.invoke({"input": query})

# Start the interactive loop when this file is executed directly.
if __name__ == "__main__":
    main()