# ReAct implementation using guidance

# Wikipedia tool

In [1]:
import wikipediaapi
import wikipedia

wiki = wikipediaapi.Wikipedia('WikiReactGuidance (selfint@gmail.com)', 'en')

In [2]:
from typing import Callable
from transformers import PreTrainedTokenizerFast
import collections
from functools import partial

def Tree():
    return collections.defaultdict(Tree)

def build_tree(sections, tree = None) -> Tree:
    if tree is None:
        tree = Tree()

    for s in sections:
        tree[s.title] = (s.text, build_tree(s.sections))

    return dict(tree)


def chunk_tokens(
    tokens: list[str],
    chunk_size: int,
    chunk_overlap: int,
    is_subword: Callable[[str], bool],
) -> list[list[str]]:
    chunks = []

    while len(tokens) > 0:
        # reduce chunk size to not end with subword
        chunk_len = chunk_size
        while (chunk_len + 1 < len(tokens)) and is_subword(tokens[chunk_len + 1]):
            chunk_len -= 1
            assert chunk_len > 0, f"got empty chunk tokens left: {len(tokens)}"

        chunk_tokens = tokens[:chunk_len]
        chunks.append(chunk_tokens)

        new_start = chunk_len
        if new_start > len(tokens):
            break

        # try to reduce overlap to not cut subwords
        overlap = chunk_overlap if chunk_overlap < chunk_len else 0
        while is_subword(tokens[new_start - overlap]):
            overlap -= 1
            if overlap == 0:
                print("failed to prevent subword cut:", tokens[new_start - 1], tokens[new_start])
                break

        tokens = tokens[new_start - overlap:]

    return chunks


def build_text_chunks(
    text: str,
    title: str,
    title_path: list[str],
    tokenizer: PreTrainedTokenizerFast,
    is_subword: Callable[[str], bool],
    chunk_size: int,
    chunk_overlap: int = 0,
) -> list[dict]:
    # build header
    header = "\n".join(
        f"{'#' * (i + 1)} {t}" for i, t in enumerate(title_path + [title])
    ) + "\n"

    # split tokens into chunk sized chunks
    text_chunk_size = chunk_size - len(tokenizer.tokenize(header))
    text_chunks_tokens = chunk_tokens(
        tokenizer.tokenize(text),
        text_chunk_size,
        chunk_overlap,
        is_subword
    )

    # build chunks
    chunks = [
        {
            "text": header + tokenizer.decoder.decode(text_chunk),
            "meta": {
                "title": title,
                "title_path": title_path
            }
        }
        for text_chunk in text_chunks_tokens
    ]

    return chunks

def build_tree_chunks(
    tree: Tree,
    title_path: list[str],
    tokenizer: PreTrainedTokenizerFast,
    is_subword: Callable[[str], bool],
    chunk_size: int,
    chunk_overlap: int = 0,
) -> list[dict]:
    assert chunk_size > chunk_overlap * 2, f"overlap must be < {chunk_size // 2=}"
    _build_text_chunks = partial(
        build_text_chunks,
        tokenizer=tokenizer,
        is_subword=is_subword,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap
    )

    chunks = []
    for k, (text, subtrees) in tree.items():
        chunks += _build_text_chunks(text, k, title_path)

        # build subtree chunks
        for subtitle, (subtext, subtree) in subtrees.items():
            chunks += _build_text_chunks(subtext, subtitle, title_path + [k])
            chunks += build_tree_chunks(
                subtree, title_path, tokenizer, is_subword, chunk_size, chunk_overlap,
            )

    return chunks

In [3]:
# used for bm25 search
import nltk
from nltk.corpus import stopwords
nltk.download("stopwords")
stop = set(stopwords.words("english"))

def preprocess(text: str) -> list[str]:
    return [
        w.strip() for w in text.lower().split(" ")
        if w.strip() not in stop
    ]

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/selfint/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


# Load Models

In [4]:
import torch
# change to "cuda" for nvidia GPUs, "mps" for Mac
torch.set_default_device("mps")

In [5]:
from FlagEmbedding import FlagReranker, FlagModel

use_fp16 = True
# in the end reranker was not used, but can be toggled on/off
reranker = FlagReranker('./models/BAAI/bge-reranker-large', use_fp16=use_fp16)
embedder = FlagModel(
    './models/BAAI/bge-large-en-v1.5',
    query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ",
    use_fp16=use_fp16
)

In [6]:
from transformers import AutoModelForCausalLM

# choose your model, make sure to adapt the guidance template
# if your model doesn't use ChatML
model_path = "./models/ehartford/dolphin-2.1-mistral-7b"
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    device_map="auto",
    torch_dtype=torch.float16,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_path, device_map="auto")

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# Guidance

In [8]:
from guidance.llms import Transformers

# ChatML template, change to match your model's prompt template
class ChatMlModel(Transformers):
    llm_name = "chatml-model"

    @staticmethod
    def role_start(role, *args, **kwargs):
       return  f"<|im_start|>{role}\n"

    @staticmethod
    def role_end(role, *args, **kwargs):
        return "<|im_end|>"

In [9]:
import guidance

llm = ChatMlModel(model=model, tokenizer=tokenizer)
guidance.llm = llm

In [10]:
from functools import lru_cache
from collections import defaultdict
from typing import Literal
from rank_bm25 import BM25Okapi

@lru_cache
def wiki_search(query, pages=1):
    return wikipedia.search(query, pages, False)

@lru_cache
def wiki_open(page):
    return wiki.page(page)

class State:
    def __init__(
        self,
        tokenizer,
        is_subword,
        reranker,
        embedder,
        chunk_size: int = 64,
        chunk_overlap: int = 16,
        rerank: bool = False,
        lookup: Literal["sem"] | Literal["bm25"] = "sem",
    ) -> None:
        self.tokenizer = tokenizer
        self.is_subword = is_subword
        self.reranker = reranker
        self.embedder = embedder

        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        self.rerank = rerank
        assert lookup in ["sem", "bm25"], f"Invalid lookup method '{lookup}'. Use 'sem' / 'bm25'."
        self.lookup = {"sem": self.lookup_sem, "bm25": self.lookup_bm25}[lookup]

        self._reset()

    def _reset(self):
        self.page = None
        self.chunks = None
        self.embeddings = None
        self.bm25 = None
        self.history = defaultdict(lambda: defaultdict(int))

    def search(self, query: str) -> str:
        self._reset()

        results = wiki_search(query, 1)
        if len(results) == 0:
            return f"Failed. Article '{query}' does not exist."
        result = results[0]

        page = wiki_open(result)
        if page == self.page:
            return f"Re-opened '{page.title}'"

        if page.exists():
            self.page = page
            self.chunks = build_tree_chunks(
                build_tree(page.sections), [page.title], self.tokenizer, self.is_subword,
                chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
            )

            # first chunk before sections
            self.chunks.extend(
                build_text_chunks(
                    page.text.split("\n\n")[0], page.title, [], self.tokenizer, self.is_subword,
                    chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
                )
            )
            self.embeddings = self.embedder.encode(
                [c["text"] for c in self.chunks]
            )
            self.bm25 = BM25Okapi([preprocess(c["text"]) for c in self.chunks])

            summary = build_text_chunks(
                page.summary, page.title, [], self.tokenizer, self.is_subword,
                chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap
            )[0]["text"]
            return f"Most relevant article for '{query}' is: '{page.title}'. Summary:\n{summary}"
        else:
            return f"Failed. Article '{query}' does not exist."

    def lookup_sem(self, question: str, top_k: int = 1) -> str:
        if self.page is None:
            return "Can't search because no article is opened, please first open a article."

        # vector search
        q_embedding = self.embedder.encode_queries([question])
        relevance_scores = (q_embedding @ self.embeddings.T).squeeze()
        relevant_chunks = [a[0] for a in sorted(zip(self.chunks, relevance_scores), key=lambda a: a[1], reverse=True)]

        reranked_chunks = relevant_chunks
        if self.rerank:
            # rerank only the most relevant chunks
            relevant_chunks = relevant_chunks[:50]
            rerank_scores = self.reranker.compute_score(
                [[question, c["text"]] for c in relevant_chunks]
            )
            reranked_chunks = [a[0] for a in sorted(zip(self.chunks, rerank_scores), key=lambda a: a[1], reverse=True)]

        # give the 'next' result if the question is repeated
        n = self.history["ask"][question]
        self.history["ask"][question] += 1
        # remove this condition to simulate "scrolling"
        if n > 0:
            return "Duplicate search action detected, try opening a more relevant article!"
        best_chunks = "\n---\n".join(c["text"] for c in reranked_chunks[n:n+top_k])
        result = f"From page '{self.page.title}':\n{best_chunks}"
        return result

    def lookup_bm25(self, item: str, top_k: int = 1) -> str:
        # give the 'next' result if the question is repeated
        n = self.history["lookup"][item]
        self.history["lookup"][item] += 1
        # remove this condition to simulate "scrolling"
        if n > 0:
            return "Duplicate lookup action detected, try opening a more relevant article!"
        best_chunks = self.bm25.get_top_n(preprocess(item), self.chunks, n=n + top_k)[n:]
        best_chunks = "\n---\n".join(c["text"] for c in best_chunks)
        result = f"From page '{self.page.title}':\n{best_chunks}"
        return result

    def act(self, action: str, content: str) -> str:
        action = action.lower().strip()
        content = content.rstrip()
        match action:
            case "search":
                return self.search(content)
            case "lookup":
                return self.lookup(content)
            case _:
                raise ValueError(f"Invalid action '{action}'")

In [13]:
state = State(
    tokenizer=embedder.tokenizer, reranker=reranker, embedder=embedder,
    is_subword=lambda t: t.startswith("##") and t != "##",
    chunk_size=128,
    chunk_overlap=32,
    rerank=False
)

system_prompt = '''\
You are a helpful assistant with access to Wikipedia. Help the user with their \
requests as best your can. Use information from Wikipedia to support your \
response. Think critically and take intelligent actions to figure out the best \
possible response to the user.

You can interact with Wikipedia by performing these steps:
1. Thought: Share your thoughts while researching. Critique previous observations.
2. Act: Here you perform actions in Wikipedia. The syntax for performing an \
action is "<action>[<input>]". Available actions are:
    * Search: Receives a subject and opens and summarizes the most relevant Wikipedia article to \
that subject. Only one article can be opened at a time.
    * Lookup: Only works *after* an article is opened. Receives a question and \
returns the most relevant passage to the question from the currently opened article.
    * Infer: Receives what to infer and returns an attempt to infer it using only the previous \
thoughts and observations. Calculates step by step, as if in a test.
    * Finish: Receives a response based on the information you found, prefixed with a summary \
of all your thoughts and observations, and sends it to the user.
3. Observation: This is the result of your action, in a ``` block.

These steps repeat until you choose the 'Finish' action.'''

def finish_react_chain(chat: list[dict], react_chain: list[dict]):
    chat[-1]["react_chain"] = list(react_chain)
    react_chain.clear()
    state._reset()

react = guidance(
    log=True,
    act=state.act,
    max_iterations=4,
    finish_react_chain=finish_react_chain,
    template='''
{{~#geneach 'react_chain' max_iterations=max_iterations~}}
Thought {{add @index 1}}: \
{{gen 'this.thought' max_tokens=256 stop_regex="\.?\n?(\n|Act|Thought|Observation|```)"}}.
Act {{add @index 1}}:
{{~select 'this.action' logprobs='this.probs' options=[' Search', ' Lookup', ' Infer', ' Finish']~}}
{{~set 'last_action' this.action~}}
[
{{~#if this.action == ' Finish'~}}
Here is a summary of my thoughts and observations followed by my conclusion. {{/if~}}
{{gen 'this.input' max_tokens=256 stop="]"~}}
]
{{~#if this.action == ' Finish'~}}
    {{~set 'answer' this.input~}}
    {{~break~}}
{{~/if}}
Observation {{add @index 1}}:
```
{{~#if this.action == ' Infer'}}
{{gen 'this.result' max_tokens=256 stop_regex="\n*```"}}
{{~else}}
{{~set 'this.result' (act action=this.action content=this.input)}}
{{this.result}}
{{~/if}}
```
{{/geneach}}
{{~#if last_action != ' Finish'~}}
Act {{add max_iterations 1}}: Finish[Here is a summary of my thoughts and observations followed by my conclusion. \
{{gen 'answer' max_tokens=256 stop_regex="(\]|\n*(Act|Thought|Observation|```))"}}]
{{/if}}
{{~finish_react_chain chat=chat react_chain=react_chain~}}
'''
)

program = guidance(
    system_prompt=system_prompt,
    await_missing=True,
    log=True,
    react=react,
    # example=example,
    print_prop=lambda c, p: str(c[-1][p]),
    template='''\
{{#system~}}
{{system_prompt}}
{{~/system}}
{{#geneach 'chat'~}}
{{#user~}}
{{set 'this.user_msg' (await 'user_msg')~}}
{{print_prop c=chat p='user_msg'}}
{{~/user}}
{{#assistant~}}
{{#block hidden=True name='this.chain'~}}
{{>react}}
{{set 'this.assistant_msg' answer~}}
{{/block~}}
{{print_prop c=chat p='assistant_msg'}}
{{~/assistant}}
{{/geneach}}
'''
)

questions = [
    "Who won the war of the bucket?",
    "What is a quazar?",
    "How old was Julius Caesar when he died?",
    "How does a catapult work?",
    "Why is the sky blue?",
    "That sounds interesting, tell me more about it.",
]

In [19]:
# one question at a time
results = [
    program(silent=True)(user_msg=q)
    for q in questions
]

# all questions together in a "chat"
result = program(silent=False)
for q in questions:
    result = result(user_msg=q)

In [22]:
from scipy.special import softmax
from textwrap import wrap

def print_chat(chat):
    def pprint(text):
        print("\n".join(wrap(text, width=120, break_long_words=False, replace_whitespace=False)))

    for step in chat:
        if "assistant_msg" not in step:
            continue

        pprint(step["user_msg"])
        print("-" * 100)
        for thought in step["react_chain"]:
            pprint("Thought: " + thought["thought"])
            p = {
                k.lower().strip(): f"{v * 100:.1f}%"
                for k, v in zip(
                    thought["probs"].keys(),
                    softmax(list(thought["probs"].values()))
                )
            }
            pprint("Action: " + thought["action"] + "(" + str(p) + ")")
            pprint("Input: " + thought["input"])
            if "result" in thought:
                pprint("Observation: " + thought["result"])
            print("-" * 100)
        pprint(step["assistant_msg"])
        print("=" * 100)

In [23]:
for result in results:
    print_chat(result["chat"])

Who won the war of the bucket?
----------------------------------------------------------------------------------------------------
Thought: The user is asking about the "war of the bucket." I need to find out what this refers to
Action:  Search({'search': '99.9%', 'lookup': '0.1%', 'infer': '0.0%', 'finish': '0.0%'})
Input: War of the Bucket
Observation: Most relevant article for 'War of the Bucket' is: 'War of the Bucket'. Summary:
# War of the Bucket
the war
of the bucket or the war of the oaken bucket ( italian : guerra della secchia rapita ) was fought in 1325 between the
rival city - states of bologna and modena. it took place in the region of emilia - romagna, in northern italy. the war
was an episode in the over 300 - year - long struggle between guelphs and ghibellines. modena won the battle of
zappolino, the only battle of the war. a common myth surrounding the war of the bucket is that it was caused by the
modenese stealing a bucket from a bolog
-----------------------------

In [24]:
print_chat(result["chat"])

That sounds interesting, tell me more about it.
----------------------------------------------------------------------------------------------------
Thought: The user is interested in learning more about a specific topic
Action:  Search({'search': '99.3%', 'lookup': '0.3%', 'infer': '0.4%', 'finish': '0.0%'})
Input: Topic
Observation: Most relevant article for 'Topic' is: 'Topic'. Summary:
# Topic
topic, topics, topic, topical, or
topicality may refer to :
----------------------------------------------------------------------------------------------------
Thought: The user wants to know more about the topic
Action:  Lookup({'search': '0.3%', 'lookup': '99.1%', 'infer': '0.6%', 'finish': '0.0%'})
Input: Question about Topic
Observation: From page 'Topic':
# Topic
topic, topics, topic, topical, or topicality may refer to :
----------------------------------------------------------------------------------------------------
Thought: Based on the information provided, I can infer that the u