# ReRank with TinyBERT

### Step 0: Import dependencies

```python
from tira.third_party_integrations import persist_and_normalize_run, load_rerank_data
from sentence_transformers import CrossEncoder
from tqdm import tqdm
import os
```

### Step 1: Load the re-rank data

```python
# Inside the TIRA sandbox, the re-ranking dataset is injected via the environment
# variable TIRA_INPUT_DIRECTORY; outside the sandbox, the default dataset is loaded.
re_rank_dataset = load_rerank_data(default='reneuir-2024/re-rank-spot-check-20240624-training')
```

```
I will use a small hardcoded example located in reneuir-2024/re-rank-spot-check-20240624-training.
```
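To try the later steps without TIRA, a small stand-in DataFrame can help. The sketch below is hypothetical: it only assumes the `query`, `text`, and `score` columns that Step 4 reads, plus `qid` and `docno` columns assumed here for identifying queries and documents. The real dataset returned by `load_rerank_data` may carry additional columns.

```python
import pandas as pd

# Hypothetical stand-in for the re-rank data: one row per (query, document) pair
# with its first-stage BM25 score. The qid/docno column names are assumptions.
toy_rerank_data = pd.DataFrame({
    'qid': ['q1', 'q1', 'q2'],
    'query': ['tiny bert', 'tiny bert', 'bm25 injection'],
    'docno': ['d1', 'd2', 'd3'],
    'text': [
        'TinyBERT is a distilled BERT model.',
        'BM25 is a lexical retrieval model.',
        'Retrieval scores can be injected as text.',
    ],
    'score': [12.5, 7.25, 3.0],
})
```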


### Step 3: Load the model

The model name is passed via the `MODEL` environment variable; if it is unset, the notebook falls back to `webis/tiny-bert-ranker-l2-bm25`.


```python
DEFAULT_MODEL = 'webis/tiny-bert-ranker-l2-bm25'
model_name = os.environ.get('MODEL', DEFAULT_MODEL)
print(f'I will use the model {model_name}.')
```

```
I will use the model webis/tiny-bert-ranker-l2-bm25.
```
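The fallback logic above can be isolated into a small function so that it can be checked without touching the real environment. This is a minimal sketch; `resolve_model_name` is a hypothetical helper, not part of TIRA. Note that `os.environ.get` only falls back when `MODEL` is absent, so an empty `MODEL=` value is returned as an empty string.

```python
import os
from typing import Mapping, Optional

DEFAULT_MODEL = 'webis/tiny-bert-ranker-l2-bm25'


def resolve_model_name(env: Optional[Mapping[str, str]] = None) -> str:
    """Return the model named in MODEL, falling back to DEFAULT_MODEL (hypothetical helper)."""
    # Use the real process environment unless a mapping is injected (e.g. in tests).
    env = os.environ if env is None else env
    return env.get('MODEL', DEFAULT_MODEL)
```

For example, `resolve_model_name({})` returns the default model, while `resolve_model_name({'MODEL': 'my-org/my-ranker'})` returns the (hypothetical) overriding name.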


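```python
# num_labels=1: the cross-encoder outputs a single relevance score per input pair.
model = CrossEncoder(model_name, num_labels=1)
```

```python
print('Model is:', model)
```

```
Model is: <sentence_transformers.cross_encoder.CrossEncoder.CrossEncoder object at 0x7f9b8466eb00>
```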


### Step 4: Make all predictions

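```python
# Min-max normalize the BM25 scores. The lower bound is fixed at 0;
# only the upper bound is taken from the data.
global_min_bm25 = 0
global_max_bm25 = max(re_rank_dataset['score'].max(), global_min_bm25)
# Avoid a division by zero when no BM25 score is above the lower bound.
score_range = (global_max_bm25 - global_min_bm25) or 1

input_data = []

for _, row in tqdm(re_rank_dataset.iterrows(), 'Transform dataset for predictions'):
    normalized_score = (row['score'] - global_min_bm25) / score_range
    # Scale to an integer percentage (truncated, e.g. 0.587 -> 58).
    normalized_score = int(normalized_score * 100)

    # Inject the normalized BM25 score as text in front of the query.
    input_data.append([f"{normalized_score} [SEP] {row['query']}", row['text']])
```

```
Transform dataset for predictions: 500it [00:00, 5254.20it/s]
```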
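The loop above turns each BM25 score s into int(100 · (s − min) / (max − min)) with min fixed at 0, and prepends it to the query as `"<score> [SEP] <query>"`. The same transformation can be sketched with pandas column operations instead of `iterrows`. This is a minimal sketch under the assumption that the DataFrame has `query`, `text`, and `score` columns; `build_inputs` is a hypothetical name.

```python
import pandas as pd


def build_inputs(df: pd.DataFrame, min_score: float = 0.0) -> list:
    """Return [query_with_score, text] pairs in the '<score> [SEP] <query>' format."""
    max_score = max(df['score'].max(), min_score)
    # Avoid dividing by zero when no score exceeds the lower bound.
    score_range = (max_score - min_score) or 1.0
    # astype(int) truncates toward zero, matching int() in the loop version.
    normalized = ((df['score'] - min_score) / score_range * 100).astype(int)
    first = normalized.astype(str) + ' [SEP] ' + df['query']
    return [list(pair) for pair in zip(first, df['text'])]
```

Vectorizing avoids building a Series per row, which `iterrows` does; for 500 rows either version is fast, so the choice is mostly about readability.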


```python
print('Make predictions')
```

```
Make predictions
```

```python
# Score each (BM25 score + query, document) pair; the predictions replace the BM25 scores.
re_rank_dataset['score'] = model.predict(input_data)
```

```
Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.
```


### Step 5: Write run file

```python
persist_and_normalize_run(re_rank_dataset, model_name, default_output='./run.txt')
```

```
The run file is normalized outside the TIRA sandbox, I will store it at "./run.txt".
Done. run file is stored under "./run.txt".
```
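`persist_and_normalize_run` takes care of writing the run file. To make `run.txt` easier to inspect, the sketch below assumes it follows the common six-column TREC run format (`qid Q0 docno rank score tag`) and shows how such lines could be produced from a DataFrame with `qid`, `docno`, and `score` columns. The function `to_trec_run` is hypothetical; it is not TIRA's implementation and does not reproduce the normalization mentioned in the log.

```python
import pandas as pd


def to_trec_run(df: pd.DataFrame, tag: str) -> str:
    """Format rows as TREC run lines: qid Q0 docno rank score tag (hypothetical helper)."""
    # Highest predicted score first within each query.
    ranked = df.sort_values(['qid', 'score'], ascending=[True, False])
    # 1-based rank per query, in the sorted row order.
    ranks = ranked.groupby('qid').cumcount() + 1
    lines = [
        f'{qid} Q0 {docno} {rank} {score} {tag}'
        for qid, docno, rank, score in zip(ranked['qid'], ranked['docno'], ranks, ranked['score'])
    ]
    return '\n'.join(lines) + '\n'
```

Passing `model_name` as `tag` would mirror how the notebook hands the model name to `persist_and_normalize_run`.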
