Install python-terrier and other libs

In [1]:
pip install python-terrier==0.10.0 fast-forward-indexes==0.2.0

Note: you may need to restart the kernel to use updated packages.


Configure PyTerrier

In [1]:
import pyterrier as pt

if not pt.started():
    pt.init(
        tqdm="notebook",
        boot_packages=["com.github.terrierteam:terrier-prf:-SNAPSHOT"]
    )

PyTerrier 0.10.0 has loaded Terrier 5.9 (built by craigm on 2024-05-02 17:40) and terrier-helper 0.0.8

No etc/terrier.properties, using terrier.default.properties for bootstrap configuration.


Import Dataset

In [2]:
dataset = pt.get_dataset("irds:beir/trec-covid")

Create a lexical index (for `BM25` and `RM3`)

In [21]:
from pathlib import Path

idx_path = Path("indices/covid_trac_idx_blocks").absolute()

index_ref = pt.index.IterDictIndexer(
    str(idx_path),
    blocks=True,
    # stopwords=None,
    # stemmer=None,
).index(dataset.get_corpus_iter(), fields=['text', 'title', 'url', 'pubmed_id'])

beir/trec-covid documents:   0%|          | 0/171332 [00:00<?, ?it/s]

ValueError: Index already exists at /Users/tomighita/Scoala/Facultate/University-Courses/RP/indices/covid_trac_idx_blocks/data.properties

Create a baseline (BM25 performance)

In [5]:
from pyterrier.measures import RR, nDCG, MAP

index = pt.IndexFactory.of(str(idx_path))

bm25 = pt.BatchRetrieve(index, wmodel="BM25")
rm3 = pt.rewrite.RM3(index)
testset = dataset
pt.Experiment(
    [bm25],
    dataset.get_topics('text'),
    dataset.get_qrels(),
    eval_metrics=[RR @ 10, nDCG @ 10, MAP @ 100],
)



Unnamed: 0,name,RR@10,nDCG@10,AP@100
0,BR(BM25),0.886857,0.640083,0.085416


# Create the Fast-Forward Indices for TCT-ColBERT

In [6]:
(bm25 % 5)(testset.get_topics('text')) #candidates

Unnamed: 0,qid,docid,docno,rank,score,query
0,1,81848,dv9m19yk,0,12.553108,what is the origin of covid 19
1,1,103419,kgifmjvb,1,12.462951,what is the origin of covid 19
2,1,123191,wmfcey6f,2,12.010235,what is the origin of covid 19
3,1,67367,4dtk1kyh,3,11.783026,what is the origin of covid 19
4,1,106832,cniyembt,4,11.665801,what is the origin of covid 19
...,...,...,...,...,...,...
49000,50,117999,xbze5s3c,0,31.931757,what is known about an mrna vaccine for the sa...
49001,50,149559,1v0f2dtx,1,29.176404,what is known about an mrna vaccine for the sa...
49002,50,132765,aju2nr9x,2,28.895280,what is known about an mrna vaccine for the sa...
49003,50,116657,ll76vrr3,3,28.218508,what is known about an mrna vaccine for the sa...


### Create the Encoder

In [7]:
from fast_forward.encoder import TCTColBERTQueryEncoder, TCTColBERTDocumentEncoder
import torch

q_encoder = TCTColBERTQueryEncoder("castorini/tct_colbert-msmarco")
d_encoder = TCTColBERTDocumentEncoder(
    "castorini/tct_colbert-msmarco",
    device="cuda:0" if torch.cuda.is_available() else "cpu",
)

Test the Encoder

In [8]:
q_encoder(["Test query 1", "Test query 2"])

array([[-0.0380525 ,  0.01848466,  0.05137944, ..., -0.04796502,
         0.00918062, -0.03880693],
       [-0.06809073,  0.02582865,  0.09803923, ..., -0.09031374,
         0.00014139, -0.06282968]], dtype=float32)

### Create the Index
*(Warning)* This operation takes a long time to complete!

In [20]:
from fast_forward import OnDiskIndex, Mode, Indexer

ff_index = OnDiskIndex(
    Path("indices/irds:beir_webis-touche2020_v2.h5"), dim=768, query_encoder=q_encoder, mode=Mode.MAXP
)

def docs_iter():
    for d in dataset.get_corpus_iter():
        yield {"doc_id": d["docno"], "text": d["text"]}

ff_indexer = Indexer(ff_index, d_encoder, batch_size=1)
ff_indexer.index_dicts(docs_iter())

ValueError: File irds:beir_webis-touche2020_v2.h5 exists.

If the index is present on disk, we can load it directly in Memory (but requires some RAM)

In [9]:
from fast_forward import OnDiskIndex, Mode

ff_index = OnDiskIndex.load(
    Path("indices/beir-covid-trec_ff.h5"), query_encoder=q_encoder, mode=Mode.MAXP
)
ff_index = ff_index.to_memory()

100%|██████████| 171332/171332 [00:00<00:00, 1039737.15it/s]


# Re-ranking BM25 Results

In [10]:
from fast_forward.util.pyterrier import FFScore

ff_score = FFScore(ff_index)

In [11]:
candidates = (bm25 % 5)(testset.get_topics('text')) # Get the candidates
candidates

Unnamed: 0,qid,docid,docno,rank,score,query
0,1,81848,dv9m19yk,0,12.553108,what is the origin of covid 19
1,1,103419,kgifmjvb,1,12.462951,what is the origin of covid 19
2,1,123191,wmfcey6f,2,12.010235,what is the origin of covid 19
3,1,67367,4dtk1kyh,3,11.783026,what is the origin of covid 19
4,1,106832,cniyembt,4,11.665801,what is the origin of covid 19
...,...,...,...,...,...,...
49000,50,117999,xbze5s3c,0,31.931757,what is known about an mrna vaccine for the sars cov 2 virus
49001,50,149559,1v0f2dtx,1,29.176404,what is known about an mrna vaccine for the sars cov 2 virus
49002,50,132765,aju2nr9x,2,28.895280,what is known about an mrna vaccine for the sars cov 2 virus
49003,50,116657,ll76vrr3,3,28.218508,what is known about an mrna vaccine for the sars cov 2 virus


In [12]:
re_ranked = ff_score(candidates)
re_ranked

Unnamed: 0,qid,docno,score_0,score,query
0,1,dv9m19yk,12.553108,3.103294,what is the origin of covid 19
1,1,kgifmjvb,12.462951,2.432137,what is the origin of covid 19
2,1,wmfcey6f,12.010235,2.650206,what is the origin of covid 19
3,1,4dtk1kyh,11.783026,2.601087,what is the origin of covid 19
4,1,cniyembt,11.665801,2.985767,what is the origin of covid 19
...,...,...,...,...,...
245,50,xbze5s3c,31.931757,2.118201,what is known about an mrna vaccine for the sars cov 2 virus
246,50,1v0f2dtx,29.176404,2.474986,what is known about an mrna vaccine for the sars cov 2 virus
247,50,aju2nr9x,28.895280,1.938545,what is known about an mrna vaccine for the sars cov 2 virus
248,50,ll76vrr3,28.218508,1.588037,what is known about an mrna vaccine for the sars cov 2 virus


In [13]:
from fast_forward.util.pyterrier import FFInterpolate

ff_int = FFInterpolate(alpha=0.5)
ff_int(re_ranked)

Unnamed: 0,qid,docno,query,score
0,1,dv9m19yk,what is the origin of covid 19,7.828201
1,1,kgifmjvb,what is the origin of covid 19,7.447544
2,1,wmfcey6f,what is the origin of covid 19,7.330221
3,1,4dtk1kyh,what is the origin of covid 19,7.192057
4,1,cniyembt,what is the origin of covid 19,7.325784
...,...,...,...,...
245,50,xbze5s3c,what is known about an mrna vaccine for the sars cov 2 virus,17.024979
246,50,1v0f2dtx,what is known about an mrna vaccine for the sars cov 2 virus,15.825695
247,50,aju2nr9x,what is known about an mrna vaccine for the sars cov 2 virus,15.416913
248,50,ll76vrr3,what is known about an mrna vaccine for the sars cov 2 virus,14.903272


In [14]:
# Find best Alpha

devset = pt.get_dataset("irds:beir/fiqa/dev")
pt.GridSearch(
    ~bm25 % 100 >> ff_score >> ff_int,
    {ff_int: {"alpha": [0.05, 0.1, 0.5, 0.9]}},
    devset.get_topics(),
    devset.get_qrels(),
    "map",
    verbose=True,
)
print(ff_int.alpha)

GridScan:   0%|          | 0/4 [00:00<?, ?it/s]

Best map is 0.000000
Best setting is ['<fast_forward.util.pyterrier.FFInterpolate object at 0x328ebf4d0> alpha=0.05']
0.05


# Results

In [33]:
testset.get_topics('text')

Unnamed: 0,qid,query
0,1,should teachers get tenure
1,2,is vaping with e cigarettes safe
2,3,should insider trading be allowed
3,4,should corporal punishment be used in schools
4,5,should social security be privatized
5,6,is a college education worth it
6,7,should felons who have completed their sentence be allowed to vote
7,8,should abortion be legal
8,9,should students have to wear school uniforms
9,10,should any vaccines be required for children


In [15]:
qe_pipeline = bm25 % 5 >> rm3
qe_pipeline(testset.get_topics('text'))

Unnamed: 0,qid,query_0,query
0,1,what is the origin of covid 19,applypipeline:off countri^0.022980893 china^0.045961786 origin^0.324229300 118^0.034471337 19^0.200000018 covid^0.200000018 travel^0.045961786 track^0.022980893 earli^0.022980893 earliest^0.022980893 global^0.022980893 canada^0.034471337
1,10,has social distancing had an impact on slowing the spread of covid 19,applypipeline:off distanc^0.177928656 state^0.017186088 counti^0.032237958 social^0.182758749 19^0.085714296 mobil^0.026818428 spread^0.132028475 impact^0.085714296 covid^0.085714296 polici^0.016772082 reduct^0.017572282 slow^0.124116361 90^0.015438061
2,11,what are the guidelines for triaging patients infected with coronavirus,applypipeline:off patient^0.120000005 applic^0.031862456 mild^0.031862456 optim^0.037467886 guidelin^0.168292180 modif^0.031862456 triag^0.205597281 minutu^0.031862456 infect^0.120000005 propos^0.037467886 coronaviru^0.120000005 walk^0.031862456 step^0.031862456
3,12,what are best practices in hospitals and at home in maintaining quarantine,applypipeline:off hospit^0.100000009 best^0.100000009 workstat^0.030459607 physic^0.055760063 maintain^0.100000009 radiolog^0.030459607 home^0.183448285 program^0.027345825 older^0.037558597 stai^0.025671918 quarantin^0.136402950 adult^0.040241349 practic^0.132651791
4,13,what are the transmission routes of coronavirus,applypipeline:off china^0.026936658 transmiss^0.287120342 confirm^0.026936658 indirect^0.026936658 rout^0.291076064 coronaviru^0.200000018 ncov^0.040404987 contact^0.026936658 possibl^0.021379795 droplet^0.026936658 outbreak^0.025335528
5,14,what evidence is there related to covid 19 super spreaders,applypipeline:off literatur^0.015347854 limit^0.030695708 hide^0.015347854 evid^0.120000005 contain^0.015347854 busi^0.043672454 19^0.120000005 super^0.225854248 spreader^0.225854248 equal^0.018987225 epidem^0.030695708 covid^0.120000005 densiti^0.018196857
6,15,how long can the coronavirus live outside the body,applypipeline:off research^0.030994300 predict^0.044318952 long^0.145977214 clinician^0.032442633 mous^0.026807157 live^0.180789247 bodi^0.152903169 ag^0.048136685 damag^0.048663948 can^0.120000005 coronaviru^0.120000005 lifespan^0.048966635
7,16,how long does coronavirus remain stable on surfaces,applypipeline:off long^0.120000005 biocid^0.035049111 remain^0.144800231 stabl^0.153562516 viabl^0.032503329 surfac^0.220687523 07^0.024800230 coronaviru^0.120000005 formul^0.026286833 environment^0.041897357 dai^0.041897357 229e^0.038515508
8,17,are there any clinical trials available for the coronavirus,applypipeline:off advers^0.024168039 clinic^0.200000018 ophthalm^0.036252063 period^0.036252063 manag^0.036252063 epidem^0.028505184 coronaviru^0.200000018 examin^0.024168039 trial^0.323622465 pneumonia^0.024168039 mask^0.030360002 subject^0.036252063
9,18,what are the best masks for preventing infection by covid 19,applypipeline:off best^0.124929860 vaccin^0.029904731 recommend^0.031456783 surgic^0.019954965 wear^0.034505464 19^0.100000009 infect^0.100000009 covid^0.100000009 prevent^0.149859697 colleg^0.019954965 pwe^0.023003642 face^0.040256370 mask^0.226173595


In [16]:
import re

def _remove_pollution(q) -> str:
    q_old = q["query"].replace('applypipeline:off', '')
    return q["query_1"] + " " + re.sub(r'\^(\d)+\.(\d)+', '', q_old)

# (qe_pipeline >> pt.apply.query(_remove_pollution))(testset.get_topics())

pipeline = qe_pipeline >> pt.apply.query(_remove_pollution) >> bm25

pipeline(testset.get_topics('text'))

Unnamed: 0,qid,docid,docno,rank,score,query_1,query_0,query
0,1,103419,kgifmjvb,0,56.978069,what is the origin of covid 19,applypipeline:off countri^0.022980893 china^0.045961786 origin^0.324229300 118^0.034471337 19^0.200000018 covid^0.200000018 travel^0.045961786 track^0.022980893 earli^0.022980893 earliest^0.022980893 global^0.022980893 canada^0.034471337,what is the origin of covid 19 countri china origin 118 19 covid travel track earli earliest global canada
1,1,123191,wmfcey6f,1,53.701980,what is the origin of covid 19,applypipeline:off countri^0.022980893 china^0.045961786 origin^0.324229300 118^0.034471337 19^0.200000018 covid^0.200000018 travel^0.045961786 track^0.022980893 earli^0.022980893 earliest^0.022980893 global^0.022980893 canada^0.034471337,what is the origin of covid 19 countri china origin 118 19 covid travel track earli earliest global canada
2,1,67667,of9wlhga,2,24.519055,what is the origin of covid 19,applypipeline:off countri^0.022980893 china^0.045961786 origin^0.324229300 118^0.034471337 19^0.200000018 covid^0.200000018 travel^0.045961786 track^0.022980893 earli^0.022980893 earliest^0.022980893 global^0.022980893 canada^0.034471337,what is the origin of covid 19 countri china origin 118 19 covid travel track earli earliest global canada
3,1,69059,8gtnbm1c,3,22.497941,what is the origin of covid 19,applypipeline:off countri^0.022980893 china^0.045961786 origin^0.324229300 118^0.034471337 19^0.200000018 covid^0.200000018 travel^0.045961786 track^0.022980893 earli^0.022980893 earliest^0.022980893 global^0.022980893 canada^0.034471337,what is the origin of covid 19 countri china origin 118 19 covid travel track earli earliest global canada
4,1,86569,ehmd66ub,4,22.163652,what is the origin of covid 19,applypipeline:off countri^0.022980893 china^0.045961786 origin^0.324229300 118^0.034471337 19^0.200000018 covid^0.200000018 travel^0.045961786 track^0.022980893 earli^0.022980893 earliest^0.022980893 global^0.022980893 canada^0.034471337,what is the origin of covid 19 countri china origin 118 19 covid travel track earli earliest global canada
...,...,...,...,...,...,...,...,...
49995,9,86718,qs9d45ky,995,10.055451,how has covid 19 affected canada,applypipeline:off teschoviru^0.023870965 individu^0.023387097 19^0.150000006 lesion^0.017903225 affect^0.250000000 mp^0.052258059 covid^0.150000006 nation^0.023870965 represent^0.034838706 parliament^0.034838706 canada^0.215161294 pig^0.023870965,how has covid 19 affected canada teschoviru individu 19 lesion affect mp covid nation represent parliament canada pig
49996,9,32550,0lphe922,996,10.051526,how has covid 19 affected canada,applypipeline:off teschoviru^0.023870965 individu^0.023387097 19^0.150000006 lesion^0.017903225 affect^0.250000000 mp^0.052258059 covid^0.150000006 nation^0.023870965 represent^0.034838706 parliament^0.034838706 canada^0.215161294 pig^0.023870965,how has covid 19 affected canada teschoviru individu 19 lesion affect mp covid nation represent parliament canada pig
49997,9,1744,j1x7js5z,997,10.051462,how has covid 19 affected canada,applypipeline:off teschoviru^0.023870965 individu^0.023387097 19^0.150000006 lesion^0.017903225 affect^0.250000000 mp^0.052258059 covid^0.150000006 nation^0.023870965 represent^0.034838706 parliament^0.034838706 canada^0.215161294 pig^0.023870965,how has covid 19 affected canada teschoviru individu 19 lesion affect mp covid nation represent parliament canada pig
49998,9,106100,i3y5l2we,998,10.051073,how has covid 19 affected canada,applypipeline:off teschoviru^0.023870965 individu^0.023387097 19^0.150000006 lesion^0.017903225 affect^0.250000000 mp^0.052258059 covid^0.150000006 nation^0.023870965 represent^0.034838706 parliament^0.034838706 canada^0.215161294 pig^0.023870965,how has covid 19 affected canada teschoviru individu 19 lesion affect mp covid nation represent parliament canada pig


In [44]:
(~bm25 % 1000)(testset.get_topics())

There are multiple query fields available: ('text', 'description', 'narrative'). To use with pyterrier, provide variant or modify dataframe to add query column.


Unnamed: 0,qid,docid,docno,rank,score,query
0,1,143806,51530f3f-2019-04-18T18:15:02Z-00004-000,0,31.770993,should teachers get tenure
1,1,164415,b0680508-2019-04-18T13:48:51Z-00002-000,1,31.583122,should teachers get tenure
2,1,4619,c065954f-2019-04-18T14:32:52Z-00003-000,2,31.521730,should teachers get tenure
3,1,4617,c065954f-2019-04-18T14:32:52Z-00001-000,3,31.476861,should teachers get tenure
4,1,163479,ff0947ec-2019-04-18T12:23:12Z-00000-000,4,31.385355,should teachers get tenure
...,...,...,...,...,...,...
48995,50,24045,630f7c6f-2019-04-18T12:52:49Z-00001-000,995,9.781759,should everyone get a universal basic income
48996,50,284651,1d684498-2019-04-18T17:05:49Z-00001-000,996,9.781667,should everyone get a universal basic income
48997,50,181658,b4c02573-2019-04-18T11:34:49Z-00000-000,997,9.780855,should everyone get a universal basic income
48998,50,320637,7539ed46-2019-04-18T11:46:02Z-00004-000,998,9.780264,should everyone get a universal basic income


In [20]:
pt.Experiment(
    [
        bm25,
        bm25 >> rm3 >> bm25,
        bm25 % 1 >> rm3 >> bm25,
        bm25 % 1000 >> ff_score >> ff_int,
        pipeline >> ff_score >> ff_int
    ],
    testset.get_topics('text'),
    testset.get_qrels(),
    eval_metrics=[RR @ 10, nDCG @ 10, MAP @ 100],
    names=[
        "BM25",
        "RM3",
        "RM3 % 1",
        "BM25 >> FF",
        "BM25 >> RM3 % 5 >> FF"
    ],
)

Unnamed: 0,name,RR@10,nDCG@10,AP@100
0,BM25,0.886857,0.640083,0.085416
1,RM3,0.844,0.660581,0.088682
2,RM3 % 1,0.884667,0.657229,0.08424
3,BM25 >> FF,0.598413,0.322953,0.028592
4,BM25 >> RM3 % 5 >> FF,0.767333,0.472624,0.043409
