# AI Chat with Tools using Mistral AI's Mixtral8x7B [Mixtral of Experts](https://mistral.ai/news/mixtral-of-experts/) Large Language Model

## This notebook downloads and initializes Mistral AI's Mixtral8x7B [Mixtral of Experts](https://mistral.ai/news/mixtral-of-experts/) using the [Huggingface transformers](https://huggingface.co/docs/transformers/index) package, and adds tool usage and multi-turn chat session logic.

The code is based on this [Pinecone Tutorial](https://www.pinecone.io/learn/mixtral-8x7b/) and extends it to include additional inference/memory improvements (4-bit/8-bit quantization and Flash Attention 2) along with full multi-step chat logic. Please check out their tutorial for a step-by-step walkthrough on the implementation.

### Tools

- "Calculator" - Enables the Mixtral AI Assistant to perform math calculations by writing Python code, which is then executed using exec(). **WARNING: Using exec() to execute arbitrary Python code can be dangerous, as it can run any Python command. This might lead to security vulnerabilities, especially in a hosted environment like Google Colab. This notebook does not currently apply any sanitization, filtering, or other safety measures to the code generated by the Mixtral model before executing it, so please be cautious and use at your own risk.**
- "Search" - Enables the Mixtral AI Assistant to search the web for real-time information using DuckDuckGo
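To make the Calculator round trip concrete, here is a minimal, standalone sketch of the idea (not the notebook's exact code): the model emits a JSON action, the Python code in its `input` field is run with exec(), and whatever that code assigns to `output` becomes the tool result. The helper name `run_calculator` is hypothetical, and passing a fresh dict as the namespace only keeps variables separate; it is **not** a sandbox.

```python
import json


def run_calculator(code: str):
    """Execute model-generated code and return whatever it assigns to `output`.

    WARNING: exec() runs arbitrary Python. A fresh namespace dict is NOT a sandbox.
    """
    # One dict serves as both globals and locals, so functions defined in the
    # generated code can still see names imported at its top level
    namespace = {}
    exec(code, namespace)
    return namespace.get("output")


# A tool call in the JSON format the system prompt asks the model to emit
raw_action = '{"tool_name": "Calculator", "input": "from math import sqrt; output = sqrt(51)"}'
action = json.loads(raw_action)
if action["tool_name"] == "Calculator":
    result = run_calculator(action["input"])
    print(f"Tool Output: {result}")
```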



# Set up

## 1) Select GPU Runtime
In the top left of Google Colab, select Runtime > Change runtime type > A100 GPU
- If you are using the free version of Google Colab, A100 GPUs may not be available. You might be able to use a different GPU with the "High-RAM" runtime setting and 4-bit quantization (set `load_in_4_bit=True`), but I've only tested this code on an A100 so far. With the single 40GB A100 provided by Google Colab, I've gotten reasonably fast inference using 4-bit quantization and Flash Attention 2 (set `use_flash_attention_2=True`), but I have not yet been able to load the model in higher precision. If you'd like a runtime with better hardware, click the resources dropdown in the top right corner (where your GPU/runtime type is displayed) and select "Connect to a custom GCE VM".
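Once the runtime is connected, a quick check like the sketch below can show which GPU you got and whether it can run Flash Attention 2. It assumes PyTorch is available (it comes preinstalled on Colab GPU runtimes) and relies on the flash-attn project's stated requirement of an Ampere-or-newer GPU (compute capability 8.0+, e.g. the A100). `supports_flash_attention_2` is a hypothetical helper, not part of any library.

```python
def supports_flash_attention_2(major: int, minor: int) -> bool:
    """Hypothetical helper: Flash Attention 2 targets Ampere (8.x) and newer GPUs."""
    return major >= 8


try:
    import torch

    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability(0)
        name = torch.cuda.get_device_name(0)
        total_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"{name}: compute capability {major}.{minor}, {total_gb:.1f} GB")
        print("Flash Attention 2 supported:", supports_flash_attention_2(major, minor))
    else:
        print("No CUDA GPU detected; check Runtime > Change runtime type")
except ImportError:
    print("PyTorch is not installed in this environment")
```

For example, an A100 reports compute capability 8.0, while a T4 reports 7.5 and would need `use_flash_attention_2=False`.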


## 2) Install Dependencies
Run the cell below to install all required package dependencies. After the installations complete, the cell kills the runtime to restart the kernel so that the changes take effect. Once the runtime restarts, move on to step 3 below to download, set up, and run inference with the Mixtral model (do not re-run the installation cell).



In [None]:
# @title Install Huggingface & dependencies

from google.colab import output, files
import os

# Tooling installations from the Pinecone tutorial - https://www.pinecone.io/learn/mixtral-8x7b/
!pip install -qU \
    transformers==4.36.1 \
    accelerate==0.25.0 \
    duckduckgo_search==4.1.0
# Huggingface transformers module (upgrades past the version pinned above)
!pip install -U transformers
# Enables 4-bit & 8-bit quantization
!pip install -U bitsandbytes
# Enables Flash Attention 2
!pip install flash-attn --no-build-isolation
output.clear()

print("Dependencies installed successfully! Restarting runtime...")
os.kill(os.getpid(), 9)
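Optionally, after the runtime restarts, you can confirm the installs in a new cell before loading the model. This is a small sketch using the standard library's `importlib.metadata`; `installed_versions` is a hypothetical helper, and the package list simply mirrors the pip commands above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = version(name)
        except PackageNotFoundError:
            versions[name] = None
    return versions


for pkg, ver in installed_versions(
    ["transformers", "accelerate", "bitsandbytes", "flash-attn", "duckduckgo_search"]
).items():
    print(f"{pkg}: {ver or 'NOT INSTALLED'}")
```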


# 3) Download and Initialize Mixtral Model

Select your model settings, then run the "**Download & Initialize Mixtral8x7B Model**" cell to load the Mixtral model. Run this cell only once to avoid reloading the model (loading takes a while); you can then run the "Run Mixtral8x7B AI Chat with Tools" cell below it as many times as you like to start new chat sessions.

### Settings
1)  `use_instruct_model: bool = True` - If True, uses the instruct-finetuned Mixtral-8x7B-Instruct-v0.1 model. If False, uses the Mixtral-8x7B-v0.1 base model.
- This notebook is designed for the Instruct model, so you will most likely get errors if you set `use_instruct_model = False`; the base model is included mainly for demonstration purposes. If you do want to use the base model, change `sys_message = None` in the "Run Mixtral8x7B AI Chat with Tools" cell of step 4 to your own text-completion prompt. Leaving it as None applies the default Instruct-model-with-tools system prompt; refer to that default `sys_message` in the `first_prompt_instruction_format(query: str, sys_message: str = None)` function, located in the "Download & Initialize Mixtral8x7B Model" cell in step 3, to make sure your new sys_message is compatible with this function.

2) `load_in_4_bit: bool = True` - If True, loads the model in reduced precision with 4-bit quantization using bitsandbytes to drastically reduce memory requirements, which is what lets the model fit on a single 40GB A100

3) `load_in_8_bit: bool = False` - If True, loads the model in reduced precision with 8-bit quantization using bitsandbytes to significantly reduce memory requirements. If both `load_in_4_bit` and `load_in_8_bit` are True, the model is loaded in 4-bit precision. Note that the 8-bit weights alone take roughly 47GB, so they likely won't fit on a single 40GB A100

4) `use_flash_attention_2: bool = True` - Uses Flash Attention 2 to reduce attention memory usage and speed up inference. Requires an Ampere-or-newer GPU (such as the A100) and fp16/bf16 weights
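The quantization settings above mostly come down to bytes per parameter. The back-of-the-envelope sketch below assumes the ~46.7B total parameter count from Mistral's Mixtral announcement and counts only the weights; it ignores the KV cache, activations, and quantization overhead, so real usage will be somewhat higher. `weight_memory_gb` is a hypothetical helper.

```python
# Approximate total parameter count for Mixtral-8x7B (~46.7B, per Mistral's announcement)
NUM_PARAMS = 46.7e9

# Storage cost per parameter for each precision
BYTES_PER_PARAM = {"bf16": 2.0, "8-bit": 1.0, "4-bit": 0.5}


def weight_memory_gb(precision: str, num_params: float = NUM_PARAMS) -> float:
    """Rough size of the model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * BYTES_PER_PARAM[precision] / 1e9


for precision in BYTES_PER_PARAM:
    print(f"{precision}: ~{weight_memory_gb(precision):.1f} GB of weights")
```

By this estimate only the 4-bit weights (roughly 23GB) leave comfortable headroom on a 40GB A100, while bf16 weights (roughly 93GB) need multiple GPUs or CPU offloading.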

---
### Text-Generation Args
The remaining settings are input arguments for Huggingface `transformers.pipeline()` text-generation pipeline. I recommend leaving them set to the default values unless you know what you're doing. For more details, please refer to the [Huggingface Pipelines Documentation](https://huggingface.co/docs/transformers/main_classes/pipelines)
- `temperature: float = 0.1` - 'Randomness' of outputs. Lower values make outputs more deterministic and higher values (values above 1.0 are allowed) make them more random; it must be greater than 0 when sampling
- `top_p: float = 0.15` - Nucleus sampling: sample only from the smallest set of most-likely tokens whose probabilities add up to at least top_p
- `top_k: int = 0` - Sample only from the top_k most likely tokens; 0 disables top-k filtering, so only top_p applies
- `do_sample: bool = True` - Enables sampling. If False, Transformers uses greedy decoding (always picking the most likely token) and ignores temperature/top_p/top_k, which is why it warns when those are set without `do_sample=True`
- `max_new_tokens: int = 512` - Maximum number of tokens to generate per response
- `repetition_penalty: float = 1.1` - Penalizes the model for repetitive text generation. If the output begins repeating, increase this parameter
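To see how temperature, top_k, and top_p interact, here is a simplified pure-Python sketch of temperature scaling followed by top-k and nucleus (top-p) filtering for a single next-token step. It illustrates the idea under simplifying assumptions and is not Transformers' actual implementation (which works on tensors through logits processors); `sample_next_token` is a hypothetical helper.

```python
import math
import random


def sample_next_token(logits, temperature=0.1, top_k=0, top_p=0.15, rng=random):
    """Illustrative temperature + top-k + top-p sampling over a list of logits."""
    # Temperature scaling: dividing by a small temperature sharpens the distribution
    scaled = [logit / temperature for logit in logits]
    # Softmax (subtract the max for numerical stability)
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Token ids ordered from most to least likely
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # top_k = 0 means "no top-k filtering"
    if top_k > 0:
        order = order[:top_k]
    # Keep the smallest prefix whose cumulative probability reaches top_p
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    # random.choices renormalizes the kept weights before sampling one token
    return rng.choices(kept, weights=[probs[i] for i in kept], k=1)[0]


print(sample_next_token([2.0, 1.0, 0.5]))
```

With the notebook's defaults (temperature 0.1, top_p 0.15, top_k 0), the most likely token usually ends up as the only candidate, which is why responses are close to deterministic.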




In [None]:
# @title Download & Initialize Mixtral8x7B Model

# Model Settings
use_instruct_model = True # @param {type:"boolean"}
load_in_4_bit = True # @param {type:"boolean"}
load_in_8_bit = False # @param {type:"boolean"}
use_flash_attention_2 = True # @param {type:"boolean"}
# Huggingface transformers.pipeline() Args
temperature = 0.1 # @param {type:"number"}
top_p = 0.15 # @param {type:"number"}
top_k = 0 # @param {type:"integer"}
do_sample = True # @param {type:"boolean"}
max_new_tokens = 512 # @param {type:"integer"}
repetition_penalty = 1.1  # @param {type:"number"}


import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import datetime
from duckduckgo_search import DDGS
import os
from google.colab import output
import logging
import io

# Set the logging level to suppress warnings
logging.getLogger('transformers').setLevel(logging.ERROR)
logging.getLogger('bitsandbytes').setLevel(logging.ERROR)


# Select Model Version
if use_instruct_model:
  model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
else:
  model_id = "mistralai/Mixtral-8x7B-v0.1"

# Load Tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_id)

from transformers import BitsAndBytesConfig

# Collect the from_pretrained() kwargs based on the settings above.
# bf16 weights are required by Flash Attention 2, and device_map="auto"
# places the weights on the available GPU(s)
model_kwargs = {
    "torch_dtype": torch.bfloat16,
    "device_map": "auto",
}
if load_in_4_bit:
    # Load model in 4-bit precision (takes priority over 8-bit)
    model_kwargs["quantization_config"] = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
elif load_in_8_bit:
    # Load model in 8-bit precision
    model_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True)
# Otherwise, load the unquantized bf16 weights

if use_flash_attention_2:
    # Use Flash Attention 2 for the attention layers
    model_kwargs["attn_implementation"] = "flash_attention_2"

model = AutoModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path=model_id,
    **model_kwargs,
)



# Create Huggingface Text Generation Pipeline
generate_text = transformers.pipeline(
    model=model,
    tokenizer=tokenizer,
    return_full_text=False,  # if using langchain set True
    task="text-generation",
    # we pass model parameters here too
    temperature=temperature,  # 'randomness' of outputs; lower is more deterministic
    top_p=top_p,  # sample from the smallest set of top tokens whose probabilities add up to top_p
    top_k=top_k,  # 0 disables top-k filtering, so sampling relies on top_p
    do_sample=do_sample,  # sample instead of greedy decoding (needed for temperature/top_p/top_k)
    max_new_tokens=max_new_tokens,  # max number of tokens to generate in the output
    repetition_penalty=repetition_penalty  # if output begins repeating increase
)



# Prints debug_message only if the "DEBUG" env variable is set to "TRUE" (case-insensitive)
def print_debug_message(debug_message: str):
  if os.environ.get("DEBUG") and os.environ.get("DEBUG").upper() == "TRUE":
    print(f"\n\n\n##### Debug Message ######\n {debug_message}\n##### End Debug Message ######\n\n\n")

# Sanitize generated_text for when the model returns its response message
# directly instead of the intended action dict JSON string. In that case the
# action dict JSON string is built manually, so newlines, tabs, etc. in the
# message are first replaced with spaces
def sanitize_text_for_json(text: str):
    # Strip whitespace and control characters from both ends of the text
    sanitized_text = text.strip()

    # Replace internal control characters like newlines with spaces
    sanitized_text = sanitized_text.replace("\n", " ").replace("\r", " ").replace("\t", " ")

    return sanitized_text

# Set up first user query formatted with initial system prompt (tool
# use instructions included in default system prompt)
def first_prompt_instruction_format(query: str, sys_message: str = None):

    if sys_message is None:
      sys_message = """You are a helpful AI assistant, you are an agent capable of using a variety of tools to answer a question. Here are a few of the tools available to you:

      - Calculator: the calculator should be used whenever you need to perform a calculation, no matter how simple. It uses Python so make sure to write complete Python code required to perform the calculation required and make sure the Python returns your answer to the `output` variable.
      - Search: the search tool should be used whenever you need to find information. It can be used to find information about everything
      - Final Answer: the final answer tool must be used to respond to the user. You must use this when you have decided on an answer.

      To use these tools you must always respond in JSON format containing `"tool_name"` and `"input"` key-value pairs. For example, to answer the question, "what is the square root of 51?" you must use the calculator tool like so:

      ```json
      {
          "tool_name": "Calculator",
          "input": "from math import sqrt; output = sqrt(51)"
      }
      ```

      Or to answer the question "who is the current president of the USA?" you must respond:

      ```json
      {
          "tool_name": "Search",
          "input": "current president of USA"
      }
      ```

      Remember, even when answering to the user, you must still use this JSON format! If you'd like to ask how the user is doing you must write:

      ```json
      {
          "tool_name": "Final Answer",
          "input": "How are you today?"
      }
      ```

      Let's get started. The user's query is as follows.
      """
      # note: don't add "</s>" to the end

    return f'<s> [INST] {sys_message} [/INST]\nUser: {query}\nAssistant: '#```json\n{{\n"tool_name": '


def extract_first_json(text: str, debug: bool = False):
    # Find the first and last braces to extract the JSON object
    # This will be more robust against irregular formatting
    start_index = text.find('{')
    end_index = text.rfind('}')  # Get the last closing brace

    if start_index == -1 or end_index == -1 or end_index < start_index:
        return None  # Return None if valid JSON braces are not found

    # Extract the substring that forms the JSON object
    json_str = text[start_index:end_index + 1]

    # Replace newline characters and other potential issues
    #json_str = json_str.replace('\n', '\\n').replace('\r', '\\r').replace('\t', '\\t')

    return json_str

def format_output(text: str):
    print_debug_message(f"format_output(): initial text: {text}\n")
    full_json_str = extract_first_json(text)

    if full_json_str is None:
        print_debug_message(f"format_output(): No valid JSON found calling extract_first_json() with text - {text}")
        return None

    print_debug_message(f"format_output(): full_json_str after extract_first_json() call - {full_json_str}\n")

    try:
        return json.loads(full_json_str)
    except json.JSONDecodeError as e:
        print_debug_message(f"format_output(): Error decoding JSON from json.loads(full_json_str) with text - {text}\nand full_json_str - {full_json_str}\nError Message - {e}")
        return None  # Handle the error as needed




# Processes the action dict created by the format_output() function to execute
# the selected tool or provide the AI's final response based on the value of
# action["tool_name"] (if tool name isn't recognized, it will assume it's a
# "Final Answer" action)
def use_tool(action: dict):
    tool_name = action["tool_name"]
    if tool_name == "Final Answer":
        is_tool_response = False
        return f"\n\nAI Assistant: {action['input']}", is_tool_response
    elif tool_name == "Calculator":
        print("\nUsing Calculator...\n")
        # Create a single dictionary to serve as the namespace for exec. Using one
        # dict for both globals and locals lets functions defined in the generated
        # code see names imported at its top level
        namespace = {}

        # Execute the code within the namespace
        exec(action["input"], namespace)

        # Access the value of 'output' from the namespace
        exec_result = namespace.get('output', None)
        is_tool_response = True
        return f"\n\nTool Output: {exec_result}", is_tool_response
    elif tool_name == "Search":
        print(f"\nSearching the Web for {action['input']}...\n")
        contexts = []
        with DDGS() as ddgs:
            results = ddgs.text(
                action["input"],
                region="wt-wt", safesearch="on",
                max_results=3
            )
            for r in results:
                contexts.append(r['body'])
        info = "\n---\n".join(contexts)
        is_tool_response = True
        return f"\n\nTool Output: {info}", is_tool_response
    else:
        # otherwise just assume final answer
        is_tool_response = False
        return f"\n\nAI Assistant: {action['input']}", is_tool_response

# Takes a full input_prompt (with prior conversation history included),
# queries the Mixtral model, processes the text-generation response, and
# iteratively executes tool calls until a "Final Answer" response is received.
# Then returns both the response_message to be displayed and the
# new input_prompt with all new messages appended
def handle_message(input_prompt):

  response = generate_text(input_prompt)
  generated_text = response[0]['generated_text']
  print_debug_message(f"handle_message(): initial generated_text: {generated_text}\n")

  # If it fails to load json, it probably just responded
  # directly without the dict, so build a dict out of it
  # in a try/except clause
  try:
      action = format_output(generated_text)
      if action is None:
        # if action is None, the formatting was wrong, meaning the model probably didn't
        # wrap its response message in an action dict, so raise an error to let except
        # block reformat it into a "Final Answer" action dict json string
        raise Exception("Failed to parse json from generated_text to create action dict.")
      response_message, is_tool_response = use_tool(action)
  except Exception as e:

      # If formatting the output and using the tool fails, then the model
      # probably didn't use action dict json string formatting, so just
      # assume its message was intended to be a direct message to the user
      # and reformat it as a "Final Answer" action
      print_debug_message(f"handle_message(): Exception Block triggered for first text-generation's format_output()/use_tool() call! Exception - {e}")

      # Since the model probably sent direct response message rather than formatting it as
      # an action dict json string, we need to remove any newlines, whitespace, etc
      # to ensure the new action dict input value is json parsable in the new_generated_text
      # json string
      generated_text = sanitize_text_for_json(text = generated_text)
      # Now build a valid "Final Answer" action dict; json.dumps escapes any
      # quotes in the message, so the JSON appended to the prompt stays parsable
      action = {"tool_name": "Final Answer", "input": generated_text}
      generated_text = "\n```json\n" + json.dumps(action, indent=4) + "\n```\n"
      response_message, is_tool_response = use_tool(action)
  # Add Initial Assistant Response to the prompt
  input_prompt += generated_text

  # If response is tool response, add the response message
  # to the prompt and query again, return the full new
  # input_prompt along with the response_message to be displayed
  # to the user
  while is_tool_response:
    input_prompt += response_message + "\n\nAI Assistant: "
    response = generate_text(input_prompt)
    generated_text = response[0]['generated_text']
    print_debug_message(f"handle_message(): while is_tool_response generated_text : {generated_text}")

    # If it fails to load json, it probably just responded
    # directly without the dict, so build a dict out of it
    # in a try/except clause
    try:
        action = format_output(generated_text)
        if action is None:
          # if action is None, the formatting was wrong, meaning the model probably didn't
          # wrap its response message in an action dict, so raise an error to let except
          # block reformat it into a "Final Answer" action dict json string
          raise Exception("Failed to parse json from generated_text to create action dict.")
        response_message, is_tool_response = use_tool(action)
    except Exception as e:

        # If formatting the output and using the tool fails, then the model
        # probably didn't use action dict json string formatting, so just
        # assume its message was intended to be a direct message to the user
        # and reformat it as a "Final Answer" action
        print_debug_message(f"handle_message(): Exception Block triggered for 'while is_tool_response' loop format_output()/use_tool() call! Exception - {e}")

        # Since the model probably sent direct response message rather than formatting it as
        # an action dict json string, we need to remove any newlines, whitespace, etc
        # to ensure the new action dict input value is json parsable in the new_generated_text
        # json string
        generated_text = sanitize_text_for_json(text = generated_text)
        # Now build a valid "Final Answer" action dict; json.dumps escapes any
        # quotes in the message, so the JSON appended to the prompt stays parsable
        action = {"tool_name": "Final Answer", "input": generated_text}
        generated_text = "\n```json\n" + json.dumps(action, indent=4) + "\n```\n"
        response_message, is_tool_response = use_tool(action)
    # Add the new is_tool_response_loop generated_text to the prompt
    input_prompt += generated_text



  return response_message, input_prompt

# Initializes AI Chat orchestrating user inputs and multi-step chat logic
def run_ai_chat(sys_message: str = None, max_user_messages: int = 10, chat_log_dir: str = "logs/"):

    # Create Chat Log Filepath to save convo history
    now = datetime.datetime.now()
    formatted_datetime = now.strftime("%Y-%m-%d_%H-%M-%S")

    if not os.path.exists(chat_log_dir):
      os.makedirs(chat_log_dir)
    chat_log_file = f"{chat_log_dir}User_Conversation_{formatted_datetime}.txt"

    # Begin AI Assistant Chat
    output.clear()
    print("\n\nBeginning AI Assistant Chat. Type 'exit' at any time to end the chat\n\n\n")
    print("\n\n")
    user_message = input("User: ")
    print("\n\n")
    input_prompt = first_prompt_instruction_format(user_message, sys_message = sys_message)
    response_message, input_prompt = handle_message(
                                          input_prompt
                                        )

    print(response_message)
    print("\n\n")
    # Save Chat Log
    with open(chat_log_file, 'w', encoding='utf-8') as file:
        file.write(input_prompt)
    # Start at 1, not 0, because the user already sent the first message
    for i in range(1, max_user_messages):
      print("\n\n")
      user_message = input("User: ")
      print("\n\n")
      input_prompt += "\nUser: " + user_message

      if user_message.upper() == 'EXIT':
        input_prompt += "\n\n\nChat Ended. Have a great day! (:"
        print("\n\n\nChat Ended. Have a great day! (:\n\n")
        # Save Chat Log
        with open(chat_log_file, 'w', encoding='utf-8') as file:
            file.write(input_prompt)
        return input_prompt
      response_message, input_prompt = handle_message(
                                          input_prompt
                                        )

      print(response_message)
      print("\n\n")
      # Save Chat Log
      with open(chat_log_file, 'w', encoding='utf-8') as file:
            file.write(input_prompt)

    print("\n\n\nChat Ended. Have a great day! (:")
    # Save Chat Log
    with open(chat_log_file, 'w', encoding='utf-8') as file:
            file.write(input_prompt)
    return input_prompt



print(f"\n\n\n{model_id} Model Loaded Successfully!")
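`extract_first_json()` slices from the first `{` to the last `}`, so a reply that contains two JSON objects (or stray braces after the action) yields text that `json.loads()` rejects. One possible alternative, sketched below, uses the standard library's `json.JSONDecoder.raw_decode`, which stops at the end of the first complete JSON value. `extract_first_json_object` is a hypothetical helper that `format_output()` could call in place of `extract_first_json()` plus `json.loads()`; it is a suggestion, not what the notebook currently does.

```python
import json


def extract_first_json_object(text: str):
    """Return the first JSON object in `text` that parses cleanly, or None."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one JSON value starting at `start` and ignores the rest
            obj, _end = decoder.raw_decode(text, start)
            if isinstance(obj, dict):
                return obj
        except json.JSONDecodeError:
            pass
        # Not valid JSON at this brace; try the next opening brace
        start = text.find("{", start + 1)
    return None


reply = 'Sure!\n```json\n{"tool_name": "Search", "input": "x"}\n```\nThen {"tool_name": "Final Answer", "input": "y"}'
print(extract_first_json_object(reply))
```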





# 4) Run Mixtral AI Chat

After the model has been downloaded and initialized, run the cell below to start your chat session.
### Settings
- `max_user_messages: int = 20` - Maximum number of messages the user can send in a given AI Chat session. This is used to prevent conversations from becoming too long to fit within the model's maximum sequence length
- `debug_mode: bool = False` - Print debug messages
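`max_user_messages` is only a proxy for prompt length. A more direct check, sketched below, compares the prompt's token count plus `max_new_tokens` against Mixtral's context window, assumed here to be 32,768 tokens (the `max_position_embeddings` value in the model's config). `fits_in_context` is a hypothetical helper; in the notebook, the token count could come from `len(tokenizer(input_prompt)["input_ids"])`.

```python
# Assumed context window for Mixtral-8x7B (max_position_embeddings in its config)
MIXTRAL_CONTEXT_WINDOW = 32_768


def fits_in_context(prompt_tokens: int, max_new_tokens: int = 512,
                    context_window: int = MIXTRAL_CONTEXT_WINDOW) -> bool:
    """True if the prompt plus the longest possible reply fits in the context window."""
    return prompt_tokens + max_new_tokens <= context_window


print(fits_in_context(30_000))
```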


In [None]:
# @title Run Mixtral8x7B AI Chat with Tools

# How many total messages the user can send in a given AI Chat session
# before it automatically ends the conversation. This prevents
# conversations from becoming too long to fit within the model's maximum
# sequence length
max_user_messages = 20 # @param {type:"integer"}
debug_mode = False # @param {type:"boolean"}

# Set sys_message to None to use default tool-enabled system prompt (default
# system prompt is located in the first_prompt_instruction_format() function
# of the "Download & Initialize Mixtral8x7B Model" cell above)
sys_message = None  # None (not a tuple) selects the default system prompt

# Set "DEBUG" env variable to toggle print_debug_message() function printing
if debug_mode:
  os.environ['DEBUG'] = 'TRUE'
else:
  os.environ['DEBUG'] = 'FALSE'

# Creates a chat_log_dir directory to save a .txt file of the full
# chat conversation for each AI Chat session (including system prompt,
# Tool Calls, Tool Outputs, etc) to help with debugging
chat_log_dir = "logs/"

# Set the logging level to suppress warnings
logging.getLogger('transformers').setLevel(logging.ERROR)
logging.getLogger('bitsandbytes').setLevel(logging.ERROR)

input_prompt = run_ai_chat(
                  sys_message = sys_message,
                  max_user_messages = max_user_messages,
                  chat_log_dir = chat_log_dir
                )
