# Search-O1

This is a prompt-based method in which the model generates search queries as part of inference.

In [1]:
#%pip install -q python-terrier accelerate pyterrier_t5
#%pip install -q git+https://github.com/terrierteam/pyterrier_rag.git

In [2]:
import pyterrier as pt
import pyterrier_rag
import torch

## Retrieval Setup

Let's get a BM25 retriever. This (Terrier) retriever also exposes the 'text' and 'title' metadata for passages.

In [3]:
sparse_index = pt.Artifact.from_hf('pyterrier/ragwiki-terrier')

# Queries from R1 may contain tokens that Terrier doesn't like. We can remove them and put them back later.
bm25 = pt.rewrite.tokenise() >> sparse_index.bm25(include_fields=['docno', 'text', 'title']) >> pt.rewrite.reset()


Java started (triggered by tokenise) and loaded: pyterrier.java, pyterrier.terrier.java [version=5.11 (build: craig.macdonald 2025-01-13 21:29), helper_version=0.0.8]


14:18:10.735 [main] WARN org.terrier.structures.BaseCompressingMetaIndex -- Structure meta reading lookup file directly from disk (SLOW) - try index.meta.index-source=fileinmem in the index properties file. 160.3 MiB of memory would be required.
14:18:10.748 [main] WARN org.terrier.structures.BaseCompressingMetaIndex -- Structure meta reading data file directly from disk (SLOW) - try index.meta.data-source=fileinmem in the index properties file. 8.2 GiB of memory would be required.
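
The tokenise-then-reset pattern used in the pipeline above can be illustrated without PyTerrier. The following is only a minimal sketch of the idea: `tokenise_query`, `retrieve_with_clean_query` and `toy_retrieve` are hypothetical names, and the regular expression is an assumption, not the tokeniser that `pt.rewrite.tokenise()` actually applies.

```python
import re

def tokenise_query(query: str) -> str:
    """Keep only alphanumeric runs, joined by single spaces (assumed cleaning rule)."""
    return " ".join(re.findall(r"[A-Za-z0-9]+", query))

def retrieve_with_clean_query(row: dict, retrieve) -> dict:
    """Retrieve with a cleaned query, but hand back the original query text."""
    original = row["query"]
    cleaned = tokenise_query(original)   # analogous to the tokenise step
    results = retrieve(cleaned)          # the retriever only ever sees the cleaned text
    # analogous to the reset step: downstream stages see the original query again
    return {"query": original, "query_used": cleaned, "results": results}

def toy_retrieve(query: str) -> list:
    """Hypothetical stand-in for a BM25 retriever."""
    return [f"doc-for:{query}"]

out = retrieve_with_clean_query({"query": 'What\'s a "redox" reaction? (chemistry)'}, toy_retrieve)
print(out["query_used"])
print(out["query"])
```

The point of restoring the original query is that later stages, such as a reranker or the reader, still see the text the model actually wrote.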


## Reader Setup

We're going to use a distilled DeepSeek-R1 model.

In [4]:
model_args = {"torch_dtype": torch.bfloat16, "trust_remote_code": True}
generator = pyterrier_rag.readers.CausalLMReader(
    "deepseek-ai/DeepSeek-R1-Distill-Llama-8B", # Or use this larger model for better answers... "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B", 
    model_args = model_args, 
    text_max_length=4096, 
    max_new_tokens=2048,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

## Search-O1 model

We invoke SearchO1 using our BM25 retrieval pipeline. By default, SearchO1 takes only 3 passages from the specified retriever.

In [5]:
o1_bm25 = pyterrier_rag.SearchO1(bm25, generator)
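
To give a feel for what an interleaved reason-and-search loop does, here is a toy sketch in plain Python. It is not the pyterrier_rag implementation: the special tokens are assumed to follow the Search-o1 prompt convention, `search_o1_loop`, `toy_generate` and `toy_retrieve` are hypothetical names, and the document-refinement step of the original method is left out entirely.

```python
import re

# Assumed special tokens marking a query written by the model and the results fed back to it.
BEGIN_Q, END_Q = "<|begin_search_query|>", "<|end_search_query|>"
BEGIN_R, END_R = "<|begin_search_result|>", "<|end_search_result|>"
QUERY_RE = re.compile(re.escape(BEGIN_Q) + r"(.*?)" + re.escape(END_Q), re.DOTALL)

def search_o1_loop(question, generate, retrieve, k=3, max_turns=10):
    """Alternate generation and retrieval until the model stops asking for searches."""
    context = question
    queries = []
    for _ in range(max_turns):
        step = generate(context)          # model continues its reasoning
        context += step
        match = QUERY_RE.search(step)
        if match is None:                 # no search requested: treat as finished
            return {"output": context, "search_queries": queries, "finished": True}
        query = match.group(1).strip()
        queries.append(query)
        passages = retrieve(query)[:k]    # only the top-k passages are shown to the model
        context += BEGIN_R + "\n".join(passages) + END_R
    return {"output": context, "search_queries": queries, "finished": False}

def toy_generate(context):
    """Hypothetical model: searches once, then answers."""
    if BEGIN_R not in context:
        return " I should look this up. " + BEGIN_Q + "chemical reaction definition" + END_Q
    return " Final answer: a process that transforms substances."

def toy_retrieve(query):
    """Hypothetical retriever returning five ranked passages."""
    return [f"p{i}" for i in range(1, 6)]

res = search_o1_loop("what are chemical reactions?", toy_generate, toy_retrieve)
print(res["search_queries"], res["finished"])
```

If the model never emits a query, the loop ends after one generation with an empty query list, which is why a run can report zero searches.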

Let's try it out. We get back a dataframe with one row, which has the generated answer in the qanswer column.

In [6]:
res = o1_bm25.search("what are chemical reactions?")
res

| | qid | question | prompt | output | finished | history | search_count | search_queries | retrieval_results | qanswer |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | what are chemical reactions? | You are a reasoning assistant with the ability... | Okay, so I need to figure out what chemical re... | True | [Okay, so I need to figure out what chemical r... | 0 | {} | {} | |


We can see the answer in the qanswer column. In the run shown here, it is an empty string.

In [7]:
res.iloc[0].qanswer

''

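The search_queries column shows the queries that were passed to the search engine. Here it is an empty set, which matches the search_count of 0: the model did not issue any search for this question.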

In [8]:
res.iloc[0].search_queries

set()

Finally, we can also see the full output of the model, including its reasoning, its generated query, the retrieved documents, and the generated final answer...

In [9]:
print(res.iloc[0].output)

Okay, so I need to figure out what chemical reactions are. Let me start by recalling what I know from school. Chemical reactions involve the interaction between different chemicals, right? But I'm a bit fuzzy on the details, so I should break it down.

First, I remember that a chemical reaction involves the transformation of chemical substances into other chemical substances. But what exactly happens in that process? I think it's something to do with atoms and molecules swapping atoms or changing their arrangement. Oh, right, it's called chemical bonding.

So, when two or more substances react, they form new substances. But how do I describe that reaction? I think it's written with an arrow, like A → B, where A is the reactant and B is the product. Wait, but sometimes there are multiple reactants and products. For example, in a balanced equation, you have a + b → c + d, where a, b, c, and d are the amounts of each substance.

Now, what makes a reaction happen? I remember something abou

## Improving the Retriever

As SearchO1 takes only the top 3 passages, precision at the top ranks is very important. Let's rerank the top 20 passages using the MonoT5 cross-encoder.

In [10]:
from pyterrier_t5 import MonoT5ReRanker
monoT5 = MonoT5ReRanker()
o1_monoT5 = pyterrier_rag.SearchO1(bm25 % 20 >> monoT5, generator)
o1_monoT5.search("what are chemical reactions?").iloc[0]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


qid                                                                  1
question                                  what are chemical reactions?
prompt               You are a reasoning assistant with the ability...
output               Okay, so I need to figure out what chemical re...
finished                                                          True
history              [Okay, so I need to figure out what chemical r...
search_count                                                         0
search_queries                                                      {}
retrieval_results                                                   {}
qanswer                                                               
Name: 0, dtype: object
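
The retrieve-then-rerank pattern (cut the first-stage ranking at 20, rescore, keep the best 3) can be sketched in plain Python. This is an illustration under assumptions: `rerank_top_n` and `toy_rescore` are hypothetical, and the term-overlap scorer merely stands in for a cross-encoder such as MonoT5.

```python
def rerank_top_n(candidates, rescore, n=20, k=3):
    """Keep the top-n first-stage candidates, rescore them, and return the best k."""
    first_stage = sorted(candidates, key=lambda d: d["score"], reverse=True)[:n]
    rescored = [{**d, "score": rescore(d)} for d in first_stage]
    return sorted(rescored, key=lambda d: d["score"], reverse=True)[:k]

# Toy first-stage ranking with BM25-like scores.
candidates = [
    {"docno": "d1", "score": 9.0, "text": "history of chemistry societies"},
    {"docno": "d2", "score": 7.5, "text": "a chemical reaction transforms reactants into products"},
    {"docno": "d3", "score": 7.0, "text": "chemical plants in industry"},
    {"docno": "d4", "score": 6.0, "text": "chemical reactions rearrange atoms"},
]

query_terms = {"chemical", "reaction", "reactions"}

def toy_rescore(doc):
    """Hypothetical reranker: counts distinct query terms present in the passage."""
    return len(query_terms & set(doc["text"].split()))

top = rerank_top_n(candidates, toy_rescore, n=20, k=3)
print([d["docno"] for d in top])
```

In the toy data, the passage ranked first by the initial scores drops out of the top 3 after rescoring, which is the effect a reranker is meant to have when only a few passages reach the reader.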

So let's see whether using monoT5 quantitatively improves the results...

## Evaluation

Now let's run a quick experiment using Natural Questions, comparing our two O1 invocations. I'm also going to add a custom measure to see how many search iterations were used by the two settings.

In [11]:
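from ir_measures import define_byquery

dataset = pt.get_dataset('rag:nq')

# Custom per-query measure: the number of searches issued for the query
# (0 if the pipeline output has no search_count column).
Iterations = define_byquery(
    lambda qrels, run: run.iloc[0].search_count if "search_count" in run.columns else 0,
    name="Iterations",
)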
pt.Experiment(
    [o1_bm25, o1_monoT5],
    dataset.get_topics('dev').head(100), # NB: remove .head(100) to run on all dev topics
    dataset.get_answers('dev'),
    [pyterrier_rag.measures.F1, pyterrier_rag.measures.EM, Iterations],
    batch_size=25,
    verbose=True,
    names=['O1-search(BM25, Deepseek-R1)', 'O1-search(monoT5, Deepseek-R1)']
)

pt.Experiment:   0%|          | 0/8 [00:00<?, ?batches/s]

The maximum number of turns 10 is exceeded, stopping...
The maximum number of turns 10 is exceeded, stopping...


pt.Experiment:  50%|█████     | 4/8 [25:56<24:43, 370.96s/batches]
monoT5:   0%|          | 0/5 [00:00<?, ?batches/s]
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.
monoT5: 100%|██████████| 5/5 [00:00<00:00, 53.11batches/s]

monoT5: 100%|██████████| 5/5 [00:00<00:00, 53.38batches/s]

monoT5: 100%|██████████| 5/5 [00:00<00:00, 53.62batches/s]

monoT5: 100%|██████████| 5/5 [00:00<00:00, 53.14batches/s]
pt.Experiment:  75%|███████▌  | 6/8 [35:38<10:45, 322.88s/batches]
monoT5: 100%|██████████| 5/5 [00:00<00:00, 53.71batches/s]

monoT5: 100%|██████████| 5/5 [00:00<00:00, 53.96batches/s]
pt.Experiment:  88%|████████▊ | 7/8 [40:48<05:18, 318.40s/batches]
monoT5: 100%|██████████| 5/5 [00:00<00:00, 53.35batches/s]
pt.Experiment: 100%|██████████| 8/8 [44:56<00:00, 337.10s/batches]


| | name | F1 | EM | Iterations |
|---|---|---|---|---|
| 0 | O1-search(BM25, Deepseek-R1) | 0.211728 | 0.15 | 0.23 |
| 1 | O1-search(monoT5, Deepseek-R1) | 0.221522 | 0.12 | 0.07 |


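So here, using monoT5 over BM25 improves the answer quality in terms of F1 (0.212 to 0.222) but not EM (0.15 to 0.12), and reduces the mean number of search iterations per query (0.23 to 0.07). With only 100 queries, these differences are small.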
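
F1 and EM can move in opposite directions because EM only rewards answers that match a gold answer exactly, while token-level F1 gives partial credit. The sketch below uses the common SQuAD-style normalisation as an assumption; `pyterrier_rag.measures` may normalise or aggregate differently, and the function names here are hypothetical.

```python
import re
import string
from collections import Counter

def normalise(text):
    """Lower-case, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(pred, golds):
    """1.0 if the normalised prediction equals any normalised gold answer."""
    return float(any(normalise(pred) == normalise(g) for g in golds))

def token_f1(pred, golds):
    """Best token-overlap F1 between the prediction and any gold answer."""
    def f1(p, g):
        p_toks, g_toks = normalise(p).split(), normalise(g).split()
        overlap = sum((Counter(p_toks) & Counter(g_toks)).values())
        if overlap == 0:
            return 0.0
        precision = overlap / len(p_toks)
        recall = overlap / len(g_toks)
        return 2 * precision * recall / (precision + recall)
    return max(f1(pred, g) for g in golds)

# A verbose but correct answer earns partial F1 credit and no EM credit.
print(exact_match("the city of Paris", ["Paris"]), token_f1("the city of Paris", ["Paris"]))
```

A system that produces longer answers containing the right tokens can therefore gain F1 while losing EM.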

## What about Dense Retrieval?

Don't fear, there is a dense index for the wiki corpus available. Instructions coming soon.