# ReAct implementation using guidance

# Wikipedia tool

In [1]:
import wikipediaapi
import wikipedia

# Wikipedia-API client for English Wikipedia ('en'), used by `wiki_open`.
# NOTE(review): the first argument looks like the User-Agent string sent with
# requests and contains a personal contact address — replace it with your own.
wiki = wikipediaapi.Wikipedia('WikiReactGuidance (selfint@gmail.com)', 'en')

In [2]:
from typing import Callable
from transformers import PreTrainedTokenizerFast
import collections
from functools import partial

def Tree():
    """Return an auto-vivifying mapping: reading a missing key creates a nested ``Tree``."""
    nested = collections.defaultdict(Tree)
    return nested

def build_tree(sections, tree=None) -> dict:
    """Convert a list of section objects into nested plain dicts.

    Every section contributes ``{section.title: (section.text, subtree)}``,
    where ``subtree`` is built recursively from ``section.sections``. Sections
    sharing a title at the same level overwrite each other (last one wins).

    Args:
        sections: Iterable of objects with ``title``, ``text`` and ``sections``
            attributes.
        tree: Optional mapping to add the entries to; a new dict is used
            when omitted.

    Returns:
        A plain ``dict`` (a copy of ``tree`` plus the new entries).
    """
    if tree is None:
        # a plain dict is enough: every key is assigned explicitly below
        tree = {}

    for section in sections:
        tree[section.title] = (section.text, build_tree(section.sections))

    return dict(tree)


def chunk_tokens(
    tokens: list[str],
    chunk_size: int,
    chunk_overlap: int,
    is_subword: Callable[[str], bool],
) -> list[list[str]]:
    """Split ``tokens`` into consecutive, possibly overlapping chunks.

    Each chunk holds at most ``chunk_size`` tokens. When possible, a chunk is
    shortened so it does not end in the middle of a word, i.e. the first token
    *after* the chunk is not a subword continuation. Consecutive chunks share
    up to ``chunk_overlap`` tokens, and the overlap is shrunk so the next chunk
    starts on a word boundary. A single word longer than ``chunk_size`` is
    hard-cut and a warning is printed.

    Args:
        tokens: Token strings to split (not modified).
        chunk_size: Maximum number of tokens per chunk; must be positive.
        chunk_overlap: Requested number of tokens shared by neighbouring
            chunks. It is ignored (treated as 0) when not smaller than the
            current chunk.
        is_subword: Returns True for a token that continues the previous
            token's word.

    Returns:
        The chunks in order; ``[]`` for empty ``tokens``. The last chunk ends
        at the last token, so no chunk consists only of overlap.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    chunks = []
    n_tokens = len(tokens)
    start = 0
    while start < n_tokens:
        end = min(start + chunk_size, n_tokens)

        # tokens[end] is the first token left out of the chunk; pull the end
        # back while that token would continue a word the chunk cuts in half
        cut = end
        while start < cut < n_tokens and is_subword(tokens[cut]):
            cut -= 1
        if cut > start:
            end = cut
        else:
            # no word boundary inside the window: fall back to a hard cut
            print("failed to prevent subword cut:", tokens[end - 1], tokens[end])
        chunks.append(tokens[start:end])

        if end >= n_tokens:
            break

        # the overlap must be shorter than the chunk so that `start` advances
        chunk_len = end - start
        overlap = chunk_overlap if 0 < chunk_overlap < chunk_len else 0
        # shrink the overlap until the next chunk starts on a word boundary
        while overlap > 0 and is_subword(tokens[end - overlap]):
            overlap -= 1

        start = end - overlap

    return chunks


def build_text_chunks(
    text: str,
    title: str,
    title_path: list[str],
    tokenizer: PreTrainedTokenizerFast,
    chunk_size: int,
    chunk_overlap: int = 0,
) -> list[dict]:
    """Chunk one section's text and prefix every chunk with its markdown header path.

    The header has one line per title in ``title_path + [title]``, with one
    more ``#`` per nesting level. The header's token count is subtracted from
    ``chunk_size`` before the text is split. Header and text are tokenized
    separately, so the combined chunk may differ slightly from the budget.

    Args:
        text: Section text to split.
        title: Title of this section.
        title_path: Titles of the enclosing sections, outermost first.
        tokenizer: Tokenizer used both to split and to decode the text.
        chunk_size: Token budget per chunk, header included.
        chunk_overlap: Tokens shared between consecutive text chunks.

    Returns:
        One ``{"text": ..., "meta": {"title": ..., "title_path": ...}}`` dict
        per chunk; empty when ``text`` produces no tokens.
    """
    header_lines = []
    for depth, heading in enumerate(title_path + [title], start=1):
        header_lines.append(f"{'#' * depth} {heading}")
    header = "\n".join(header_lines) + "\n"

    # leave room for the header inside every chunk
    budget = chunk_size - len(tokenizer.tokenize(header))
    pieces = chunk_tokens(
        tokenizer.tokenize(text),
        budget,
        chunk_overlap,
        is_subword=lambda token: not token.startswith("▁"),
    )

    result = []
    for piece in pieces:
        result.append({
            "text": header + tokenizer.decoder.decode(piece),
            "meta": {"title": title, "title_path": title_path},
        })
    return result

def build_tree_chunks(
    tree: dict,
    title_path: list[str],
    tokenizer: PreTrainedTokenizerFast,
    chunk_size: int,
    chunk_overlap: int = 0,
) -> list[dict]:
    """Flatten a section tree into header-prefixed text chunks.

    ``tree`` maps a section title to ``(text, subtree)`` with the same shape
    recursively, as produced by ``build_tree``. Each section's text is chunked
    with ``build_text_chunks``. Its header lists the full path of titles from
    ``title_path`` down to the section, at any depth. Sections are emitted
    depth-first in mapping order, each section before its subsections.

    Args:
        tree: Section tree to flatten.
        title_path: Titles enclosing the top level of ``tree``.
        tokenizer: Tokenizer forwarded to ``build_text_chunks``.
        chunk_size: Token budget per chunk.
        chunk_overlap: Overlap between chunks; must be < ``chunk_size / 2``.

    Returns:
        All chunks of the tree in document order.

    Raises:
        AssertionError: If ``chunk_overlap`` is too large for ``chunk_size``.
    """
    assert chunk_size > chunk_overlap * 2, f"overlap must be < {chunk_size // 2=}"

    chunks = []
    for title, (text, subtree) in tree.items():
        chunks += build_text_chunks(
            text, title, title_path, tokenizer,
            chunk_size=chunk_size, chunk_overlap=chunk_overlap,
        )
        # recurse with this section appended, so deeper sections keep their
        # full ancestry in the header (not just the top-level path)
        chunks += build_tree_chunks(
            subtree, title_path + [title], tokenizer, chunk_size, chunk_overlap,
        )

    return chunks

In [3]:
# used for bm25 search
import nltk
from nltk.corpus import stopwords
# Fetch NLTK's English stopword list (skipped when already up to date) and keep
# it as a set for fast membership tests in `preprocess`.
nltk.download("stopwords")
stop = set(stopwords.words("english"))

def preprocess(text: str) -> list[str]:
    return [
        w.strip() for w in text.lower().split(" ")
        if w.strip() not in stop
    ]

[nltk_data] Downloading package stopwords to
[nltk_data]     /Users/selfint/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


# Load Models

In [4]:
import torch
# Pick the default torch device automatically: "cuda" for NVIDIA GPUs, "mps"
# for Apple silicon, otherwise "cpu". To force a device, replace the selection
# below with a fixed string.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"
torch.set_default_device(DEVICE)

In [5]:
from FlagEmbedding import FlagReranker, FlagModel

# Local BAAI retrieval checkpoints; `use_fp16` is forwarded to both models.
use_fp16 = True
# Loaded even though it ends up unused by default: reranking only happens when
# requested in `lookup_sem(..., rerank=True)`.
reranker = FlagReranker('./models/BAAI/bge-reranker-large', use_fp16=use_fp16)
# Bi-encoder for chunk/query embeddings; the instruction is the prefix for
# retrieval queries.
embedder = FlagModel(
    './models/BAAI/bge-large-en-v1.5',
    use_fp16=use_fp16,
    query_instruction_for_retrieval="Represent this sentence for searching relevant passages: ",
)

In [6]:
from transformers import AutoModelForCausalLM

# Choose your model. If it doesn't use the ChatML prompt format, adapt the
# guidance template (ChatMlModel) accordingly.
model_path = "./models/ehartford/dolphin-2.1-mistral-7b"
# half-precision weights; device_map="auto" lets transformers decide placement
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    torch_dtype=torch.float16,
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
from transformers import AutoTokenizer

# `device_map` is a model-loading option (used for `model` above); the
# tokenizer has no device placement, so it is loaded without it.
tokenizer = AutoTokenizer.from_pretrained(model_path)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# Guidance

In [8]:
from guidance.llms import Transformers

# ChatML template, change to match your model's prompt template
class ChatMlModel(Transformers):
    """guidance ``Transformers`` backend that frames chat turns with ChatML markers.

    NOTE(review): guidance presumably calls ``role_start``/``role_end`` around
    each ``{{#system}}``/``{{#user}}``/``{{#assistant}}`` block. Confirm this
    against the installed guidance version.
    """

    llm_name = "chatml-model"

    @staticmethod
    def role_start(role, *args, **kwargs):
        """Marker opening a turn: ``<|im_start|>``, the role name, then a newline."""
        return "<|im_start|>{}\n".format(role)

    @staticmethod
    def role_end(role, *args, **kwargs):
        """Marker closing a turn; identical for every role."""
        end_marker = "<|im_end|>"
        return end_marker

In [9]:
import guidance

# Wrap the local model in the ChatML backend and make it guidance's default LLM.
llm = ChatMlModel(tokenizer=tokenizer, model=model)
guidance.llm = llm

In [11]:
from functools import lru_cache
from collections import defaultdict
from rank_bm25 import BM25Okapi

@lru_cache
def wiki_search(query, pages=1):
    """Return Wikipedia titles matching ``query``, memoized per argument pair.

    NOTE(review): ``pages`` is the second positional argument of
    ``wikipedia.search`` (presumably the maximum number of results), and the
    third argument ``False`` presumably disables suggestions. Confirm against
    the ``wikipedia`` package. The cached list is shared between calls, so
    callers must not mutate it.
    """
    titles = wikipedia.search(query, pages, False)
    return titles

@lru_cache
def wiki_open(page):
    """Return the ``wiki`` page object for title ``page``, memoized.

    Because of ``lru_cache``, repeated calls with the same title return the
    very same object.
    """
    article = wiki.page(page)
    return article

class State:
    def __init__(self, tokenizer, reranker, embedder) -> None:
        self.tokenizer = tokenizer
        self.reranker = reranker
        self.embedder = embedder

        self._reset()

    def _reset(self):
        self.page = None
        self.chunks = None
        self.embeddings = None
        self.bm25 = None
        self.history = defaultdict(lambda: defaultdict(int))

    def search(self, query: str, chunk_size: int = 128, chunk_overlap: int = 32) -> str:
        self._reset()

        results = wiki_search(query, 1)
        if len(results) == 0:
            return f"Failed. Article '{query}' does not exist."
        result = results[0]

        page = wiki_open(result)
        if page == self.page:
            return f"Success. Article '{page.title}' can now be used."

        if page.exists():
            self.page = page
            self.chunks = build_tree_chunks(
                build_tree(page.sections), [page.title], self.tokenizer,
                chunk_size=chunk_size, chunk_overlap=chunk_overlap
            )

            # first chunk before sections
            self.chunks.extend(
                build_text_chunks(
                    page.text.split("\n\n")[0], page.title, [], self.tokenizer,
                    chunk_size=chunk_size, chunk_overlap=chunk_overlap
                )
            )
            self.embeddings = self.embedder.encode(
                [c["text"] for c in self.chunks]
            )
            self.bm25 = BM25Okapi([preprocess(c["text"]) for c in self.chunks])

            return f"Article '{page.title}' can now be used."
        else:
            return f"Failed. Article '{query}' does not exist."

    def lookup_sem(self, question: str, top_k: int = 1, rerank: bool = False) -> str:
        if self.page is None:
            return "Can't search because no article is opened, please first open a article."

        # vector search
        q_embedding = self.embedder.encode_queries([question])
        relevance_scores = (q_embedding @ self.embeddings.T).squeeze()
        relevant_chunks = [a[0] for a in sorted(zip(self.chunks, relevance_scores), key=lambda a: a[1], reverse=True)]

        reranked_chunks = relevant_chunks
        if rerank:
            # rerank only the most relevant chunks
            relevant_chunks = relevant_chunks[:50]
            rerank_scores = self.reranker.compute_score(
                [[question, c["text"]] for c in relevant_chunks]
            )
            reranked_chunks = [a[0] for a in sorted(zip(self.chunks, rerank_scores), key=lambda a: a[1], reverse=True)]

        # give the 'next' result if the question is repeated
        n = self.history["ask"][question]
        self.history["ask"][question] += 1
        # remove this condition to simulate "scrolling"
        if n > 0:
            return "Duplicate search action detected, try opening a more relevant article!"
        best_chunks = "\n---\n".join(c["text"] for c in reranked_chunks[n:n+top_k])
        result = f"From page '{self.page.title}':\n{best_chunks}"
        return result

    def lookup_bm25(self, item: str, top_k: int = 1) -> str:
        # give the 'next' result if the question is repeated
        n = self.history["lookup"][item]
        self.history["lookup"][item] += 1
        # remove this condition to simulate "scrolling"
        if n > 0:
            return "Duplicate lookup action detected, try opening a more relevant article!"
        best_chunks = self.bm25.get_top_n(preprocess(item), self.chunks, n=n + top_k)[n:]
        best_chunks = "\n---\n".join(c["text"] for c in best_chunks)
        result = f"From page '{self.page.title}':\n{best_chunks}"
        return result

    def act(self, action: str, content: str) -> str:
        action = action.lower().strip()
        content = content.rstrip()
        match action:
            case "search":
                return self.search(content)
            case "lookup":
                return self.lookup_sem(content)
            case "_lookup":
                return self.lookup_bm25(content)
            case _:
                raise ValueError(f"Invalid action '{action}'")

In [15]:
# Shared Wikipedia environment; its `act` method executes the agent's actions.
state = State(embedder=embedder, reranker=reranker, tokenizer=tokenizer)

# System prompt describing the Think / Act / Observation loop to the model.
# Trailing backslashes join lines, so each bullet is a single prompt line.
system_prompt = '''\
You are a helpful assistant with access to Wikipedia. Help the user with their \
requests as best you can. Use information from Wikipedia to support your \
response. Think critically and take intelligent actions to figure out the best \
possible response to the user.

You can interact with Wikipedia by performing these steps:
1. Think: Think about these things:
    * What should your response look like?
    * What information is needed for a well-sourced response?
    * What information was already found in previous searches?
    * Which action would be the most effective to find the missing information?
2. Act: Here you perform actions in Wikipedia. The syntax for performing an \
action is "<action>[<input>]". Available actions are:
    * Search: Receives a query and opens the most relevant Wikipedia article to \
that query. Only one article can be opened at a time.
    * Lookup: Only works *after* an article is opened. Receives a question and \
returns the most relevant passage to the question from the currently opened article.
    * Finish: Receives a response based on the information you found and sends it to the user.
3. Observation: This is the result of your action, in a ``` block.

These steps repeat until you choose the 'Finish' action.'''

# ReAct sub-program: up to `max_iterations` Think/Act/Observation rounds, then
# a forced Finish if the model never chose one.
# The template is a regular (non-raw) string on purpose: "\n" must become real
# newlines and a trailing "\" joins lines. Regex backslashes are therefore
# doubled ("\\.", "\\]") so Python sees no invalid escape sequences; the
# resulting template text is unchanged.
react = guidance(
    log=True,
    act=state.act,
    max_iterations=4,
    template='''
{{~#geneach 'react_chain' max_iterations=max_iterations~}}
Think {{add @index 1}}: \
{{gen 'this.thought' max_tokens=256 stop_regex="\\.?\n?\n(Act|Think|Result|```)"}}.
Act {{add @index 1}}:
{{~select 'this.action' logprobs='this.probs' options=[' Search', ' Lookup', ' Finish']~}}
{{~set 'last_action' this.action~}}
[{{gen 'this.input' max_tokens=256 stop="]"}}]
{{~#if this.action == ' Finish'~}}
    {{~set 'answer' this.input~}}
    {{~break~}}
{{~/if}}
Observation {{add @index 1}}:
```
{{set 'this.result' (act action=this.action content=this.input)~}}
{{this.result}}
```
{{/geneach}}
{{~#if last_action != ' Finish'~}}
Act {{add max_iterations 1}}: Finish[Based on the information I found (note that I expanded \
my calculations step by step for clarity). \
{{gen 'answer' max_tokens=256 stop_regex="(\\]|\n*(Act|Think|Result|```))"}}]
{{/if}}
{{~set 'chat[-1].react_chain' react_chain}}
'''
)

# Outer chat loop. Each turn awaits a `user_msg`, runs the ReAct sub-program
# (`react`) inside a hidden block, and sets the assistant message to `answer`.
# The template calls `print_prop` with keyword arguments `c` and `p`, so the
# lambda's parameter names must stay as they are.
program = guidance(
    template='''\
{{#system~}}
{{system_prompt}}
{{~/system}}
{{#geneach 'chat'~}}
{{#user~}}
{{set 'this.user_msg' (await 'user_msg')~}}
{{print_prop c=chat p='user_msg'}}
{{~/user}}
{{#assistant~}}
{{#block hidden=True name='this.chain'~}}
{{>react}}
{{set 'this.assistant_msg' answer~}}
{{/block~}}
{{print_prop c=chat p='assistant_msg'}}
{{~/assistant}}
{{/geneach}}
''',
    system_prompt=system_prompt,
    react=react,
    print_prop=lambda c, p: str(c[-1][p]),
    await_missing=True,
    log=True,
)

# Scripted conversation: each call fills the awaited `user_msg` and returns the
# updated program, which is fed the next message.
user_messages = [
    "Who won the war of the bucket?",
    "What is a quazar?",
    "How does a catapult work?",
    "How old was Caesar when he died?",
    "Why is the sky blue?",
    "That sounds interesting, tell me more about it.",
]
result = program(silent=False)
for user_message in user_messages:
    result = result(user_msg=user_message)

In [16]:
from scipy.special import softmax
from textwrap import wrap

def print_chat(chat):
    """Pretty-print every answered turn of ``chat`` together with its ReAct trace.

    Turns without an ``"assistant_msg"`` are skipped. For each step in a
    turn's ``"react_chain"``, the following are printed, wrapped at 120 columns:
    the thought, the chosen action with the softmax of the per-option values
    in ``"probs"``, the input, and the observation ``"result"`` when present.

    Args:
        chat: List of turn dicts as stored by the chat program.
    """
    def pprint(text):
        # Wrap line by line: wrapping the whole text with
        # replace_whitespace=False leaves embedded newlines in the middle of
        # wrapped lines. `or [""]` keeps blank lines (wrap("") returns []).
        for line in text.splitlines() or [""]:
            print("\n".join(wrap(line, width=120, break_long_words=False) or [""]))

    for step in chat:
        if "assistant_msg" not in step:
            continue

        pprint(step["user_msg"])
        print("-" * 100)
        for thought in step["react_chain"]:
            pprint("Thought: " + thought["thought"])
            probs = softmax(list(thought["probs"].values()))
            p = {
                option.lower().strip(): f"{prob * 100:.1f}%"
                for option, prob in zip(thought["probs"].keys(), probs)
            }
            # select options carry a leading space (" Search"); strip it for display
            pprint("Action: " + thought["action"].strip() + "(" + str(p) + ")")
            pprint("Input: " + thought["input"])
            if "result" in thought:
                pprint("Result: " + thought["result"])
            print("-" * 100)
        pprint(step["assistant_msg"])
        print("=" * 100)

In [17]:
# `result` is the executed chat program from the conversation cell above
# (the previous `r` was undefined on a fresh kernel).
print_chat(result["chat"])

Who won the war of the bucket?
----------------------------------------------------------------------------------------------------
Thought: The user is asking about the War of the Bucket, which seems to be a historical event or conflict. To provide a
well-sourced response, I need to find information about this event and its outcome
Action:  Search({'search': '99.8%', 'lookup': '0.2%', 'finish': '0.0%'})
Input: War of the Bucket
Result: Article 'War of the Bucket' can now be used.
----------------------------------------------------------------------------------------------------
Thought: Now that I have opened the article, I can use the Lookup action to find information about the winner of the war
Action:  Lookup({'search': '0.0%', 'lookup': '100.0%', 'finish': '0.0%'})
Input: Who won the War of the Bucket?
Result: From page 'War of the Bucket':
# War of the Bucket
The War of the Bucket or the War of the Oaken Bucket
(Italian: Guerra della secchia rapita) was fought in 1325 between th