In [1]:
import pandas as pd
import os
import gzip
import json
from tqdm import tqdm

# Run configuration; every value can be overridden through an environment variable.
# os.environ.get returns a str when the variable is set but the (int) default otherwise,
# so numeric settings are coerced with int(): their type no longer depends on where the
# value came from, and a non-numeric override fails here instead of inside the
# subprocess command line these values are interpolated into.
batch_size = int(os.environ.get('LIT5_BATCH_SIZE', 18))
total_n_rerank_passages = int(os.environ.get('LIT5_RERANK_PASSAGES', 100))
windowsize = int(os.environ.get('LIT5_WINDOW_SIZE', 20))
stride = int(os.environ.get('LIT5_STRIDE', 10))
model_path = os.environ.get('LIT5_MODEL_PATH', 'castorini/LiT5-Distill-base')
text_maxlength = int(os.environ.get('LIT5_TEXT_MAX_LENGTH', 150))


In [2]:
input_dataset = "/tmp/rerank.jsonl.gz"
output_file = "/tmp/run.txt"

if os.environ.get('TIRA_INPUT_DATASET'):
    # We are inside the TIRA sandbox, we process the injected dataset.
    # TIRA_OUTPUT_DIR is required in that case; os.environ[...] raises a KeyError naming the
    # missing variable instead of an opaque TypeError from `None + str`.
    input_dataset = os.path.join(os.environ['TIRA_INPUT_DATASET'], 'rerank.jsonl.gz')
    output_file = os.path.join(os.environ['TIRA_OUTPUT_DIR'], 'run.txt')
else:
    # We are outside the TIRA sandbox, we process a small constructed hello world example
    scored_docs = pd.DataFrame([
        {"qid": "1", "query": "hubble telescope achievements", "docno": "doc-3", "text": "The Hubble telescope discovered two moons of Pluto, Nix and Hydra.", "rank": 1, "score": 10},
        {"qid": "1", "query": "hubble telescope achievements", "docno": "doc-4", "text": "Edwin Hubble, an astronomer with great achievement, completely reimagined our place in the universe (the telescope is named by him).", "rank": 2, "score": 9},
        {"qid": "2", "query": "how to exit vim?", "docno": "doc-1", "text": "Press ESC key, then the : (colon), and type the wq command after the colon and hit the Enter key to save and leave Vim.", "rank": 1, "score": 10},
        {"qid": "2", "query": "how to exit vim?", "docno": "doc-2", "text": "In Vim, you can always press the ESC key on your keyboard to enter the normal mode in the Vim editor.", "rank": 2, "score": 9},
        {"qid": "3", "query": "signs heart attack", "docno": "doc-5", "text": "Common heart attack symptoms include: (1) Chest pain, (2) Pain or discomfort that spreads to the shoulder, arm, back, neck, jaw, teeth or sometimes the upper belly, etc.", "rank": 1, "score": 10},
        {"qid": "3", "query": "signs heart attack", "docno": "doc-6", "text": "A heart attack happens when the flow of blood that brings oxygen to your heart muscle suddenly becomes blocked. ", "rank": 2, "score": 9},
    ])
    # Compress explicitly: the next cell reads this file with gzip.open, so it must be gzip
    # regardless of whether the installed pandas infers compression from the ".gz" suffix.
    scored_docs.to_json(input_dataset, lines=True, orient='records', compression='gzip')

In [3]:
def transform_to_lit5_dataset(input_dataset, output_file, max_document_length=3500):
    """Convert a gzipped re-rank JSONL file into the per-query JSONL format fed to LiT5.

    Each input line is a JSON object with at least the keys ``qid``, ``query``,
    ``docno`` and ``text``. All lines sharing a ``qid`` are grouped into one output
    record ``{"question": query, "id": qid, "ctxs": [{"docid": docno, "text": text}, ...]}``.
    Queries and their documents keep the order in which they first appear in the input.

    Args:
        input_dataset: Path to the gzip-compressed JSONL input.
        output_file: Path of the uncompressed JSONL file to write (one record per query).
        max_document_length: Each document text is truncated to this many characters.
    """
    qid_to_lit5_format = {}

    # JSON is UTF-8; don't depend on the (possibly non-UTF-8) locale of the sandbox.
    with gzip.open(input_dataset, 'rt', encoding='utf-8') as input_file:
        for line in tqdm(input_file, 'Reformat input dataset.'):
            if not line.strip():
                # Tolerate blank lines instead of failing inside json.loads.
                continue
            record = json.loads(line)
            qid = record['qid']
            if qid not in qid_to_lit5_format:
                qid_to_lit5_format[qid] = {"question": record['query'], "id": qid, "ctxs": []}
            qid_to_lit5_format[qid]["ctxs"].append(
                {"docid": record['docno'], "text": record['text'][:max_document_length]}
            )

    # Write the grouped records directly as JSON lines (dicts preserve insertion order).
    with open(output_file, 'w', encoding='utf-8') as out:
        for entry in qid_to_lit5_format.values():
            out.write(json.dumps(entry) + '\n')


# Convert the gzipped input into the LiT5 format and point `input_dataset` at the result,
# which is what the re-ranking command consumes. `input_dataset` is rebound here, so the
# conversion is skipped when it already points at the converted (uncompressed) file:
# otherwise re-running this cell alone would try to gunzip a plain JSONL file.
lit5_input_dataset = '/tmp/rerank-lit5.jsonl'
if input_dataset != lit5_input_dataset:
    transform_to_lit5_dataset(input_dataset, lit5_input_dataset)
input_dataset = lit5_input_dataset

Reformat input dataset.: 6it [00:00, 6466.04it/s]


In [5]:
# Echo the effective re-ranking configuration so it is recorded in the cell output.
print('Start re-ranking with:')
for _label, _value in (
    ('\tbatch_size:', batch_size),
    ('\ttotal_n_rerank_passages:', total_n_rerank_passages),
    ('\twindowsize:', windowsize),
    ('\tstride:', stride),
    ('\tmodel_path:', model_path),
    ('\ttext_maxlength:', text_maxlength),
):
    print(_label, _value)

Start re-ranking with:
	batch_size: 18
	total_n_rerank_passages: 100
	windowsize: 20
	stride: 10
	model_path: castorini/LiT5-Distill-base
	text_maxlength: 150


In [6]:
# Re-rank with LiT5-Distill in a subprocess. Judging from the tool's log output
# ("Reranking passages: 80 to 100", then "70 to 90", ...), passages are processed in
# windows of `windowsize` passages that move by `stride`, starting from the end.
# Path arguments are double-quoted so paths containing spaces survive shell word-splitting.
# NOTE(review): `--answer_maxlength 100` is a hard-coded constant — confirm its meaning in LiT5-Distill.py.
!python3 ../castorini-list-in-t5/FiD/LiT5-Distill.py \
    --model_path "{model_path}" \
    --batch_size {batch_size} \
    --eval_data "{input_dataset}" \
    --n_passages {windowsize} \
    --runfile_path "{output_file}" \
    --text_maxlength {text_maxlength} \
    --answer_maxlength 100 \
    --stride {stride} \
    --n_rerank_passages {total_n_rerank_passages} \
    --bfloat16

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Start Inference
Reranking passages: 80 to 100
Reranking passages: 70 to 90
Reranking passages: 60 to 80
Reranking passages: 50 to 70
Reranking passages: 40 to 60
Reranking passages: 30 to 50
Reranking passages: 20 to 40
Reranking passages: 10 to 30
Reranking passages: 0 to 20


In [10]:
print('Re-ranking is done. First 4 entries of the run file:')
# Quoted so an output path containing spaces reaches `head` as a single argument.
!head -4 "{output_file}"

Re-ranking is done. First 4 entries of the run file:
1 Q0 doc-4 1 1.0 RankFiD
1 Q0 doc-3 2 0.5 RankFiD
2 Q0 doc-1 1 1.0 RankFiD
2 Q0 doc-2 2 0.5 RankFiD
