# ReAct implementation using guidance

# Wikipedia tool

In [1]:
import wikipediaapi
import wikipedia

# Shared Wikipedia API client (English edition); the first argument is the
# user-agent string sent with every request.
wiki = wikipediaapi.Wikipedia(
    user_agent='WikiReactGuidance (selfint@gmail.com)',
    language='en',
)

In [2]:
from typing import Callable
from transformers import PreTrainedTokenizerFast
import collections
from functools import partial

def Tree():
    """Create an auto-vivifying mapping whose missing keys become nested ``Tree``s."""
    node = collections.defaultdict(Tree)
    return node

def build_tree(sections, tree = None) -> Tree:
    """Convert a list of section objects into a nested plain ``dict``.

    Each section must expose ``title``, ``text`` and ``sections`` (its
    sub-sections). The result maps every title to a ``(text, subtree)`` pair,
    where ``subtree`` is the same structure built from the sub-sections.

    Args:
        sections: iterable of section objects for one nesting level.
        tree: optional mapping to fill in; a fresh ``Tree`` is used when omitted.

    Returns:
        A ``dict`` copy of the filled mapping.
    """
    root = Tree() if tree is None else tree

    for section in sections:
        text = section.text
        children = build_tree(section.sections)
        root[section.title] = (text, children)

    return dict(root)


def chunk_tokens(
    tokens: list[str],
    chunk_size: int,
    chunk_overlap: int,
    is_subword: Callable[[str], bool],
) -> list[list[str]]:
    """Split ``tokens`` into chunks of at most ``chunk_size`` tokens.

    Chunk boundaries are moved backwards so that no chunk ends in the middle
    of a word, i.e. the token following a chunk is never a sub-word
    continuation. Consecutive chunks share up to ``chunk_overlap`` trailing
    tokens; the overlap is shrunk (possibly to zero) so that the next chunk
    also starts on a word boundary.

    Args:
        tokens: the token strings to split.
        chunk_size: maximum number of tokens per chunk.
        chunk_overlap: desired number of tokens repeated at the start of the
            next chunk. Ignored (treated as 0) when it is not smaller than
            the chunk that was just produced, or when it is negative.
        is_subword: predicate returning True for tokens that continue the
            previous word.

    Returns:
        The list of chunks, in order. Empty when ``tokens`` is empty.

    Raises:
        AssertionError: when no word boundary exists inside the window
            (a single word longer than ``chunk_size`` tokens) or when
            ``chunk_size`` is not positive.
    """
    chunks = []
    total = len(tokens)
    start = 0

    while start < total:
        end = min(start + chunk_size, total)

        # Shrink the chunk until the token right after it starts a new word.
        # The token to inspect is tokens[end] (the first token NOT in the chunk).
        if end < total:
            while end > start and is_subword(tokens[end]):
                end -= 1
        assert end > start, "got empty chunk"

        chunks.append(tokens[start:end])

        # Everything has been consumed: stop here instead of emitting a
        # trailing chunk that only repeats the overlap.
        if end >= total:
            break

        # Shrink the overlap until the next chunk starts on a word boundary.
        # tokens[end] is a word start (guaranteed above), so an overlap of 0
        # is always a clean cut and the loop never needs to go negative.
        overlap = chunk_overlap if 0 < chunk_overlap < end - start else 0
        while overlap > 0 and is_subword(tokens[end - overlap]):
            overlap -= 1

        # overlap < end - start, so every iteration makes progress.
        start = end - overlap

    return chunks


def build_text_chunks(
    text: str,
    title: str,
    title_path: list[str],
    tokenizer: PreTrainedTokenizerFast,
    chunk_size: int,
    chunk_overlap: int = 0,
) -> list[dict]:
    """Split one section's text into chunks that each start with a title header.

    The header is a markdown-style breadcrumb built from ``title_path`` plus
    ``title`` (one ``#`` per nesting level). Its token count is subtracted
    from ``chunk_size`` so that header + text stays within ``chunk_size``
    tokens.

    Args:
        text: the section body.
        title: the section title.
        title_path: titles of the enclosing page/sections, outermost first.
        tokenizer: tokenizer used both to count tokens and to decode chunks
            back to text.
        chunk_size: token budget of a full chunk, header included.
        chunk_overlap: number of tokens shared between consecutive chunks.

    Returns:
        One dict per chunk with keys ``"text"`` (header + decoded chunk) and
        ``"meta"`` (``title`` and ``title_path``). Empty when ``text``
        tokenizes to nothing.
    """
    # build header
    header = "\n".join(
        f"{'#' * (i + 1)} {t}" for i, t in enumerate(title_path + [title])
    ) + "\n"

    # split tokens into chunk sized chunks, leaving room for the header
    text_chunk_size = chunk_size - len(tokenizer.tokenize(header))
    text_chunks_tokens = chunk_tokens(
        tokenizer.tokenize(text),
        text_chunk_size,
        chunk_overlap,
        # tokens that do not start with the word-boundary marker continue a word
        is_subword = lambda t: not t.startswith("▁")
    )

    # build chunks
    return [
        {
            "text": header + tokenizer.decoder.decode(text_chunk),
            "meta": {
                "title": title,
                "title_path": title_path
            }
        }
        for text_chunk in text_chunks_tokens
    ]

def build_tree_chunks(
    tree: Tree,
    title_path: list[str],
    tokenizer: PreTrainedTokenizerFast,
    chunk_size: int,
    chunk_overlap: int = 0,
) -> list[dict]:
    """Build text chunks for every section of a section tree, depth first.

    Args:
        tree: mapping of section title to ``(text, subtree)`` as produced by
            ``build_tree``.
        title_path: titles of the ancestors of this tree level, outermost
            first (``None`` or empty for a root).
        tokenizer: tokenizer forwarded to ``build_text_chunks``.
        chunk_size: token budget per chunk, header included.
        chunk_overlap: number of tokens shared between consecutive chunks.

    Returns:
        The chunks of each section followed by the chunks of its
        sub-sections, in tree order. Every chunk header lists the full chain
        of ancestor titles, whatever the nesting depth.

    Raises:
        AssertionError: when ``chunk_overlap`` is not below half of
            ``chunk_size``.
    """
    assert chunk_size > chunk_overlap * 2, f"overlap must be < {chunk_size // 2=}"

    chunks = []
    title_path = title_path or []

    for title, (text, subtrees) in tree.items():
        chunks += build_text_chunks(
            text, title, title_path, tokenizer, chunk_size, chunk_overlap
        )

        # Recurse with the path extended by this section's title so that
        # nested sections keep all of their ancestors in the header.
        chunks += build_tree_chunks(
            subtrees, title_path + [title], tokenizer, chunk_size, chunk_overlap,
        )

    return chunks

In [3]:
from FlagEmbedding import FlagReranker
import torch

# Use the Apple Metal (MPS) backend when it is available; otherwise keep
# PyTorch's default device so the notebook also runs on machines without MPS.
if torch.backends.mps.is_available():
    torch.set_default_device("mps")

# Cross-encoder used to score (query, chunk) pairs; loaded from a local path.
reranker = FlagReranker('./models/BAAI/bge-reranker-large')

In [4]:
from tqdm.auto import tqdm

def search_wikipedia(query, top_pages, top_chunks, chunk_size, chunk_overlap, thought = None):
    """Search Wikipedia and return the chunks that best match ``query``.

    Args:
        query: search string; also used as the reranking query.
        top_pages: number of search results (pages) to fetch and chunk.
        top_chunks: maximum number of chunks to return.
        chunk_size: token budget per chunk (reranker tokenizer).
        chunk_overlap: number of tokens shared between consecutive chunks.
        thought: accepted for interface compatibility; currently unused.

    Returns:
        Up to ``top_chunks`` chunk dicts, best score first. Each chunk's
        ``"meta"`` additionally holds the source ``"page"``. Empty when the
        search yields no chunk at all.
    """
    results = wikipedia.search(query, top_pages, False)
    pages = [wiki.page(r) for r in results]
    chunks = []

    for page in tqdm(pages, desc="Chunking page contents"):
        tree = build_tree(page.sections)
        page_chunks = build_tree_chunks(
            tree, [page.title], reranker.tokenizer, chunk_size, chunk_overlap
        )
        for c in page_chunks:
            c["meta"]["page"] = page
        chunks += page_chunks

    # Nothing to rank: avoid calling the reranker with an empty batch.
    if not chunks:
        return []

    scores = reranker.compute_score(
        [[query, c["text"]] for c in chunks]
    )
    # NOTE(review): some FlagEmbedding versions return a bare float when
    # given a single pair; normalise so that zip() below always works.
    if not hasattr(scores, "__len__"):
        scores = [scores]

    ranked = sorted(zip(chunks, scores), key=lambda pair: pair[1], reverse=True)
    return [chunk for chunk, _ in ranked][:top_chunks]

In [5]:
from textwrap import wrap

def print_chunks(chunks):
    """Print the text of each chunk wrapped to 100 columns, separated by ``--`` lines."""
    rendered = []
    for chunk in chunks:
        lines = wrap(
            chunk["text"],
            width=100,
            break_long_words=False,
            replace_whitespace=False,
        )
        rendered.append("\n".join(lines))

    print("\n--\n".join(rendered))

# Load Model

In [6]:
from transformers import AutoModelForCausalLM
import torch

# Local checkpoint directory, reused below when loading the tokenizer.
model_path = "./models/ehartford/dolphin-2.1-mistral-7b"

# Load the causal LM with half-precision (float16) weights.
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", torch_dtype=torch.float16)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [7]:
from transformers import AutoTokenizer

# Tokenizer matching the chat model. A tokenizer holds no tensors, so the
# `device_map` argument that was passed here had no effect and is dropped.
tokenizer = AutoTokenizer.from_pretrained(model_path)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


# Guidance

In [8]:
from guidance.llms import Transformers

class ChatMlModel(Transformers):
    """Transformers-backed guidance LLM that wraps chat roles in ChatML markers."""

    llm_name = "chatml-model"

    @staticmethod
    def role_start(role, *args, **kwargs):
        """Return the text that opens a message for ``role``."""
        return "<|im_start|>{}\n".format(role)

    @staticmethod
    def role_end(role, *args, **kwargs):
        """Return the text that closes a message (identical for every role)."""
        return "<|im_end|>"

In [9]:
import guidance

# Wrap the loaded model and tokenizer in the ChatML adapter and register it
# as guidance's default LLM.
guidance.llm = llm = ChatMlModel(model=model, tokenizer=tokenizer)

In [10]:
from functools import lru_cache

@lru_cache
def wiki_search(query, pages=1):
    """Return up to ``pages`` Wikipedia search results for ``query`` (memoized)."""
    return wikipedia.search(query, results=pages, suggestion=False)

@lru_cache
def wiki_open(page):
    """Return the page object for the title ``page`` from the shared ``wiki`` client (memoized)."""
    opened = wiki.page(page)
    return opened

class State:
    def __init__(self, reranker) -> None:
        self.page = None
        self.chunks = None
        self.reranker = reranker

    def open(self, query: str) -> str:
        self.page = None

        results = wiki_search(query, 1)
        if len(results) == 0:
            return f"No entry found for '{query}'"
        result = results[0]

        page = wiki_open(result)
        if page == self.page:
            return f"Loaded page '{page.title}', summary:\n{page.summary}"

        if page.exists():
            self.page = page
            self.chunks = build_tree_chunks(
                build_tree(page.sections), [page.title], tokenizer,
                chunk_size=256, chunk_overlap=64
            )
            return f"Loaded page '{page.title}', summary:\n{page.summary}"
        else:
            return f"No entry found for '{query}'"

    def search(self, query: str) -> str:
        results = wiki_search(query, 1)
        if len(results) == 0:
            return f"No entry found for '{query}'"

        result = results[0]
        page = wiki_open(result)

        if page.exists():
            page = page
            chunks = build_tree_chunks(
                build_tree(page.sections), [page.title], tokenizer,
                chunk_size=256, chunk_overlap=64
            )
            scores = self.reranker.compute_score(
                [[query, c["text"]] for c in chunks]
            )

            best = [a[0] for a in sorted(zip(chunks, scores), key=lambda a: a[1], reverse=True)]
            return best[0]["text"]
        else:
            return f"No entry found for '{query}'"

    def lookup(self, query: str) -> str:
        if self.page is None:
            return "Can't lookup because no page is opened, please first open a page."

        scores = self.reranker.compute_score(
            [[query, c["text"]] for c in self.chunks]
        )

        best = [a[0] for a in sorted(zip(self.chunks, scores), key=lambda a: a[1], reverse=True)]
        return best[0]["text"]

    def act(self, action: str, content: str) -> str:
        action = action.lower().strip()
        content = content.rstrip()
        print(f"{action=}")
        print(f"{content=}")
        match action:
            case "open":
                return self.open(content)
            case "lookup":
                return self.lookup(content)
            case "search":
                return self.search(content)
            case _:
                raise ValueError(f"Invalid action '{action}'")

In [27]:
# Fresh tool state (no page opened); its `act` method is the tool entry point
# handed to the `react` program below.
state = State(reranker=reranker)

# System prompt describing the Think / Act / Observation protocol and the
# Open, Lookup and Finish actions offered to the model.
system_prompt = '''\
You are WikiBot, an assistant that can answer questions about the world \
by using information available in Wikipedia, and by explaining your thoughts \
and actions to the user.

Given a question, here are the stages of your response:
1. Think: Explain to the user your thought process behind how you are trying to \
figure out the answer. Decide which action you are going to perform next, think step by step.
2. Act: Here you perform actions. The syntax for an action is "<action>[<input>]". \
Available actions are:
    * Open: Open a Wikipedia page will also return the summary of the page. \
For example: `Open[George Washington]`.
    * Lookup: Only works *after* a page is opened, use this action to find information within \
the opened page. For example, *after* `Open[George Washington]`, you could then `Lookup[Date of birth]`.
    * Finish: When you are confident you can answer the user question, in its \
entirety, you Finish your response with the answer. For example: `Finish[The first \
president of the United States is George Washington. He was elected in 1789.]`
3. Observation: This is the result of your action.

These stages repeat until you choose the Finish action and provide the final \
answer. Only decide to Finish when you have figured out everything you need in \
order to answer the question. Make sure to reflect on your progress after each observation, \
and make any necessary adjustments to your plan. Make sure that when you can answer the question \
you always use the Finish action.'''

# Inner ReAct loop. Each iteration generates a thought, constrains the action
# to one of Open / Lookup / Finish, then generates the bracketed input. For
# Open and Lookup the result of `state.act` is inserted as the observation;
# for Finish the input is stored as `answer` and the loop breaks.
# NOTE(review): `State.act` also accepts "search", but the template never
# offers it as an option.
react = guidance(
    log=True,
    act=state.act,
    template='''\
{{#geneach 'react_chain' stop=False~}}
Thought: {{gen 'this.thought' max_tokens=512 stop_regex="\n?\nAct"}}
Act:{{select 'this.act_action' options=[' Open', ' Lookup', ' Finish']}}[\
{{gen 'this.act_input' max_tokens=512 stop="]"}}\
]
{{#if this.act_action != ' Finish'~}}
Observation:
```
{{set hidden=True 'this.observation' (act action=this.act_action content=this.act_input)~}}
{{this.observation}}
```
{{else~}}
{{set 'answer' this.act_input~}}
{{break~}}
{{/if~}}
{{/geneach~}}'''
)

# Outer chat loop. For every awaited `user_msg` it runs `react` inside a
# hidden block (captured as `this.chain`) and emits only the final answer as
# the assistant message. `print_prop` returns property `p` of the last item
# of the list `c` as a string.
program = guidance(
    system_prompt=system_prompt,
    await_missing=True,
    react=react,
    print_prop=lambda c, p: str(c[-1][p]),
    template='''\
{{#system~}}
{{system_prompt}}
{{~/system}}
{{#geneach 'chat'~}}
{{#user~}}
{{set 'this.user_msg' (await 'user_msg')~}}
{{print_prop c=chat p='user_msg'}}
{{~/user}}
{{#assistant~}}
{{#block hidden=True name='this.chain'~}}
{{>react}}
{{set 'this.assistant_msg' answer~}}
{{/block~}}
{{print_prop c=chat p='assistant_msg'}}
{{~/assistant}}
{{/geneach}}
'''
)

# Run two turns: each call supplies the next awaited `user_msg`.
r = program(silent=False)\
    (user_msg="Who was the first US president?")\
    (user_msg="When was he born?")

# Display the per-turn records (user message, answer and hidden chain).
r['chat']

[{'user_msg': 'Who was the first US president?',
  'assistant_msg': 'The first president of the United States was George Washington.',
  'chain': 'Thought: I need to find out who the first president of the United States was. I will open the Wikipedia page for the United States to gather information.\nAct: Open[United States]\nObservation:\n```\nLoaded page \'United States\', summary:\nThe United States of America (USA), commonly known as the United States (U.S.) or simply America, is a country primarily located in North America and consisting of 50 states, a federal district, five major unincorporated territories, and nine Minor Outlying Islands. It includes 326 Indian reservations. It is the world\'s third-largest country by both land and total area. It shares land borders with Canada to its north and with Mexico to its south and has maritime borders with the Bahamas, Cuba, Russia, and other nations. With a population of over 333 million, it is the most populous country in the America

In [30]:
# Continue the same conversation with a third question.
# NOTE(review): this rebinds `r` to the result of calling itself, so
# re-running the cell presumably appends yet another turn — confirm.
r = r(user_msg="Who won the battle of Pharsalus?")

In [33]:
# Show the hidden reasoning chain of the third turn (index 2), which is the
# Pharsalus question when the cells are run once, in order.
print(r["chat"][2]["chain"])

Thought: I need to look up the Battle of Pharsalus to find out who won it.
Act: Open[Battle of Pharsalus]
Observation:
```
Loaded page 'Battle of Pharsalus', summary:
The Battle of Pharsalus was the decisive battle of Caesar's Civil War fought on 9 August 48 BC near Pharsalus in Central Greece. Julius Caesar and his allies formed up opposite the army of the Roman Republic under the command of Pompey. Pompey had the backing of a majority of Roman senators and his army significantly outnumbered the veteran Caesarian legions.
Pressured by his officers, Pompey reluctantly engaged in battle and suffered an overwhelming defeat, ultimately fleeing the camp and his men, disguised as an ordinary citizen. Eventually making his way to Egypt, he was assassinated upon his arrival at the order of Ptolemy XIII.
```
Thought: From the summary, it seems that Julius Caesar won the battle of Pharsalus.
Act: Lookup[Winner of the Battle of Pharsalus]
Observation:
```
# Battle of Pharsalus
## Battle
Caesar u