# Search-R1

This notebook demonstrates Search-R1, a model trained to generate search queries as part of its inference process. The original implementation is `infer.py` in https://github.com/PeterGriffinJin/Search-R1, by the original authors.


In [1]:
#%pip install -q python-terrier accelerate

In [2]:
import pyterrier as pt
import pyterrier_rag

## Retrieval Setup

Let's get a BM25 retriever. This Terrier index also stores the `text` and `title` metadata for each passage, so the retriever can return them with each result.

In [3]:
sparse_index = pt.Artifact.from_hf('pyterrier/ragwiki-terrier')

# queries generated by R1 may contain characters that Terrier's query parser doesn't accept.
# tokenise() strips them before retrieval; reset() restores the original query afterwards.
bm25 = pt.rewrite.tokenise() >> sparse_index.bm25(include_fields=['docno', 'text', 'title']) >> pt.rewrite.reset()


Java started (triggered by tokenise) and loaded: pyterrier.java, pyterrier.terrier.java [version=5.11 (build: craig.macdonald 2025-01-13 21:29), helper_version=0.0.8]


21:54:50.500 [main] WARN org.terrier.structures.BaseCompressingMetaIndex -- Structure meta reading lookup file directly from disk (SLOW) - try index.meta.index-source=fileinmem in the index properties file. 160.3 MiB of memory would be required.
21:54:50.515 [main] WARN org.terrier.structures.BaseCompressingMetaIndex -- Structure meta reading data file directly from disk (SLOW) - try index.meta.data-source=fileinmem in the index properties file. 8.2 GiB of memory would be required.
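The `tokenise() >> ... >> reset()` wrapping can be pictured with the plain-Python sketch below. It is only an analogy under stated assumptions: `tokenise_query` is a hypothetical regex stand-in for Terrier's real tokeniser, and `retrieve` is any callable returning result rows as dicts.

```python
import re
from typing import Callable, Dict, List


def tokenise_query(query: str) -> str:
    # Hypothetical stand-in for Terrier's tokeniser: keep runs of word
    # characters, lowercase them and join them with single spaces.
    return " ".join(re.findall(r"\w+", query.lower()))


def retrieve_with_reset(query: str, retrieve: Callable[[str], List[Dict]]) -> List[Dict]:
    # Remember the raw query before cleaning it.
    original = query
    # Retrieve with the cleaned query only.
    results = retrieve(tokenise_query(query))
    # Put the original query back on every result row.
    return [dict(row, query=original) for row in results]
```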


## Search-R1 model

We instantiate SearchR1 with our BM25 retrieval pipeline. By default, SearchR1 uses only the top 3 passages returned by the specified retriever.

In [4]:
r1_bm25 = pyterrier_rag.SearchR1(bm25)

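To show where those 3 passages go, here is a simplified sketch of a Search-R1 style inference loop. The model reasons inside `<think>` tags, issues queries inside `<search>` tags, receives the retrieved passages inside `<information>` tags, and finishes with an `<answer>`. This is not the pyterrier_rag implementation: `generate` and `retrieve` are hypothetical callables, and `max_turns` is an illustrative parameter.

```python
import re
from typing import Callable, List, Optional, Tuple

SEARCH_RE = re.compile(r"<search>(.*?)</search>", re.DOTALL)
ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def search_r1_loop(
    question: str,
    generate: Callable[[str], str],
    retrieve: Callable[[str], List[str]],
    k: int = 3,
    max_turns: int = 4,
) -> Tuple[Optional[str], List[str], str]:
    trace = question
    queries: List[str] = []
    for _ in range(max_turns):
        # The (hypothetical) model continues the trace, stopping after </search> or </answer>.
        step = generate(trace)
        trace += step
        answer = ANSWER_RE.search(step)
        if answer:
            return answer.group(1).strip(), queries, trace
        search = SEARCH_RE.search(step)
        if search is None:
            break  # neither a query nor an answer: give up
        query = search.group(1).strip()
        queries.append(query)
        # Only the top-k passages are shown to the model (SearchR1 uses 3 by default).
        passages = retrieve(query)[:k]
        trace += "<information>" + "\n".join(passages) + "</information>"
    return None, queries, trace
```

Because only `k` passages are appended to the trace, anything the retriever ranks below position 3 is invisible to the model.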

Let's try it out. We get back a dataframe with one row, with the generated answer in the `qanswer` column.

In [5]:
r1_bm25.search("what are chemical reactions?").iloc[0]

qid                                                            1
query                               what are chemical reactions?
qanswer        chemical transformation of one set of chemical...
output         <think>I found out that chemical reactions are...
iteration                                                      1
all_queries                 [(0,  what are chemical reactions )]
Name: 0, dtype: object

## Improving the Retriever

Because SearchR1 sees only the top 3 passages, precision at the very top of the ranking is very important. Let's rerank the top 20 BM25 passages using the MonoT5 cross-encoder.
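The effect of a rank cutoff followed by a reranker can be sketched in plain Python. Here `rerank_score` is a hypothetical stand-in for monoT5: the reranker can only promote passages that made it into the top 20, and only the top 3 after reranking would reach SearchR1.

```python
from typing import Callable, List, Tuple


def cutoff_then_rerank(
    ranking: List[Tuple[str, float]],
    rerank_score: Callable[[str], float],
    cutoff: int = 20,
    final_k: int = 3,
) -> List[str]:
    # ranking holds (docno, first-stage score) pairs, best first.
    candidates = ranking[:cutoff]  # analogous to `bm25 % 20`
    # Rescore only the surviving candidates with the (hypothetical) reranker.
    rescored = sorted(candidates, key=lambda pair: rerank_score(pair[0]), reverse=True)
    # Only the first final_k passages would be passed on to SearchR1.
    return [docno for docno, _ in rescored[:final_k]]


ranking = [(f"d{i}", 100.0 - i) for i in range(30)]
print(cutoff_then_rerank(ranking, lambda d: int(d[1:])))  # ['d19', 'd18', 'd17']
```

Under this toy reranker `d25` would score higher still, but it never gets a chance because it falls outside the cutoff.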

In [8]:
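from pyterrier_t5 import MonoT5ReRanker
monoT5 = MonoT5ReRanker()
# keep the top 20 BM25 passages (% is PyTerrier's rank cutoff), then rescore them with monoT5;
# clone_for_retriever gives a SearchR1 instance that uses this pipeline as its retriever
r1_monoT5 = r1_bm25.clone_for_retriever(bm25 % 20 >> monoT5)
r1_monoT5.search("what are chemical reactions?").iloc[0]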



qid                                                            1
query                               what are chemical reactions?
qanswer        chemical transformation of one set of chemical...
output         <think>I found out that chemical reactions are...
iteration                                                      1
all_queries                 [(0,  what are chemical reactions )]
Name: 0, dtype: object

In this case, the answer appears unchanged. But let's see whether reranking matters quantitatively...

## Evaluation

Now let's run a quick experiment on Natural Questions, comparing our two R1 configurations. I'm also adding a custom measure that records how many search/think iterations each setting used.
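`define_byquery` evaluates its function once per query, on that query's rows of the run, and (as far as I know) reports the mean of those per-query values. The sketch below mimics that aggregation in plain Python over hypothetical result rows, so an `Iterations` value such as 3.5 should be read as an average over topics rather than a per-query count.

```python
from statistics import mean
from typing import Dict, List


def mean_iterations(run_rows: List[Dict]) -> float:
    # Keep the first row seen for each query, mirroring run.iloc[0] in the lambda.
    first_row: Dict[str, Dict] = {}
    for row in run_rows:
        first_row.setdefault(row["qid"], row)
    # One value per query, then the mean across queries.
    return mean(row["iteration"] for row in first_row.values())
```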

In [18]:
from ir_measures import define_byquery
dataset = pt.get_dataset('rag:nq')
# custom measure: the number of iterations Search-R1 used for each query
Iterations = define_byquery(lambda qrels, run: run.iloc[0].iteration, name="Iterations")
pt.Experiment(
    [r1_bm25, r1_monoT5],
    dataset.get_topics('dev').head(2), # NB: remove .head(2) to run on all dev topics
    dataset.get_answers('dev'),
    [pyterrier_rag.measures.F1, pyterrier_rag.measures.EM, Iterations],
    batch_size=25,
    verbose=True,
    names=['R1(BM25)', 'R1(monoT5)']
)



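|   | name       | F1       | EM  | Iterations |
|---|------------|----------|-----|------------|
| 0 | R1(BM25)   | 0.333333 | 0.0 | 3.5        |
| 1 | R1(monoT5) | 0.5      | 0.5 | 3.5        |

These numbers come from only the first two dev topics (`.head(2)`), so they are illustrative rather than conclusive.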
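For reference, F1 and EM are commonly computed SQuAD-style: normalise both strings (lowercase, strip punctuation and articles, collapse whitespace), then compare tokens, taking the best score over all gold answers. The sketch below follows that common definition; it is an assumption that `pyterrier_rag.measures.F1` and `EM` normalise in exactly this way. Per query, EM is 0 or 1, so averaged over two topics an EM of 0.5 would correspond to one exact match.

```python
import re
import string
from collections import Counter
from typing import List

PUNCT = set(string.punctuation)


def normalize(text: str) -> str:
    # Lowercase, drop punctuation, drop English articles, collapse whitespace.
    text = text.lower()
    text = "".join(ch for ch in text if ch not in PUNCT)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, gold_answers: List[str]) -> float:
    # 1.0 if the normalised prediction equals any normalised gold answer.
    return float(any(normalize(prediction) == normalize(g) for g in gold_answers))


def token_f1(prediction: str, gold: str) -> float:
    # Harmonic mean of token precision and recall (multiset overlap).
    pred_tokens = normalize(prediction).split()
    gold_tokens = normalize(gold).split()
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def f1(prediction: str, gold_answers: List[str]) -> float:
    # Best F1 against any of the gold answers.
    return max(token_f1(prediction, g) for g in gold_answers)
```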


## What about Dense Retrieval?

Don't worry: a dense index for the Wikipedia corpus is also available. Instructions are coming soon.