# RankLLaMA

This Jupyter notebook uses [RankLLaMA](https://huggingface.co/castorini/rankllama-v1-7b-lora-passage/) to re-rank documents in the context of [ReNeuIR 2024](https://reneuir.org/). See the [corresponding publication](https://arxiv.org/pdf/2310.08319) for more details.

### Step 1: Import Dependencies

In [2]:
import os
import torch
import pandas as pd
from tqdm import tqdm
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import PeftModel, PeftConfig
from tira.third_party_integrations import ir_datasets, persist_and_normalize_run
from huggingface_hub import snapshot_download

### Step 2: Load and Transform the Dataset

In [3]:
# Importing ir_datasets via "from tira.third_party_integrations import ir_datasets" patches "ir_datasets.load"
# so that, when executed within the TIRA sandbox, it loads the dataset injected into the sandbox.
# Here, outside the sandbox, we only check that the notebook runs on a minimal spot-check dataset.
dataset = ir_datasets.load('reneuir-2024/re-rank-spot-check-20240624-training')

In [4]:
qid_to_rerank_data = {}

for i in dataset.scoreddocs_iter():
    if i.query_id not in qid_to_rerank_data:
        qid_to_rerank_data[i.query_id] = {'query': i.query.default_text(), 'search_results': []}

    qid_to_rerank_data[i.query_id]['search_results'].append(
        {'docid': i.doc_id, 'text': i.document.default_text(), 'score': i.score}
    )
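The cell above groups the first-stage scored documents by query, keeping each query's text once and its candidate documents in first-stage order. A minimal, dependency-free sketch of the same grouping follows; `ScoredDoc` and `group_by_query` are hypothetical names, and the toy records use plain strings where the patched ir_datasets objects expose `.query` and `.document`.

```python
from typing import Dict, Iterable, NamedTuple


class ScoredDoc(NamedTuple):
    # Hypothetical stand-in for the records yielded by dataset.scoreddocs_iter();
    # the real objects expose .query and .document instead of plain strings.
    query_id: str
    query_text: str
    doc_id: str
    doc_text: str
    score: float


def group_by_query(scoreddocs: Iterable[ScoredDoc]) -> Dict[str, dict]:
    grouped: Dict[str, dict] = {}
    for sd in scoreddocs:
        # setdefault creates the per-query entry the first time a query id is seen
        entry = grouped.setdefault(sd.query_id, {'query': sd.query_text, 'search_results': []})
        entry['search_results'].append({'docid': sd.doc_id, 'text': sd.doc_text, 'score': sd.score})
    return grouped


toy = [
    ScoredDoc('q1', 'what is bm25', 'd1', 'BM25 is a ranking function.', 12.3),
    ScoredDoc('q1', 'what is bm25', 'd2', 'Okapi BM25 weights term frequency.', 11.0),
    ScoredDoc('q2', 'llama', 'd3', 'LLaMA is a language model.', 9.5),
]
grouped = group_by_query(toy)
print({qid: [r['docid'] for r in v['search_results']] for qid, v in grouped.items()})
# {'q1': ['d1', 'd2'], 'q2': ['d3']}
```

Because Python dicts preserve insertion order, both the query order and the per-query candidate order follow the order in which the scored documents are iterated.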


### Step 3: Re-Rank
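
`get_model` in the next cell loads RankLLaMA as a LoRA adapter on top of the Llama-2-7B base model and then calls PEFT's `merge_and_unload()`, which folds each adapter's low-rank update into the corresponding base weight so that inference runs on a plain transformers model. As a minimal sketch of why merging does not change the scores (up to floating-point rounding), the NumPy snippet below uses toy sizes, random matrices, and hypothetical helper names (`lora_forward`, `merge_lora`), and assumes the standard LoRA parameterization with scaling `alpha / r`.

```python
import numpy as np


def lora_forward(W, A, B, x, alpha, r):
    # Unmerged LoRA layer: frozen base weight W plus the scaled low-rank path B @ (A @ x)
    return W @ x + (alpha / r) * (B @ (A @ x))


def merge_lora(W, A, B, alpha, r):
    # Fold the low-rank update into one dense weight of the same shape as W
    return W + (alpha / r) * (B @ A)


rng = np.random.default_rng(0)
d_out, d_in, r, alpha = 6, 8, 2, 16          # toy sizes, not RankLLaMA's real ones
W = rng.standard_normal((d_out, d_in))       # base weight, shape (out, in)
A = rng.standard_normal((r, d_in))           # LoRA down-projection, shape (r, in)
B = rng.standard_normal((d_out, r))          # LoRA up-projection, shape (out, r)
x = rng.standard_normal(d_in)

merged = merge_lora(W, A, B, alpha, r)
print(np.allclose(merged @ x, lora_forward(W, A, B, x, alpha, r)))  # True
```

The merged layer needs one matrix product per forward pass instead of three, which is why merging before inference is the usual choice when the adapter will not be trained further.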

In [6]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print('Will use device:', device)

def get_model(peft_model_name):
    config = PeftConfig.from_pretrained(peft_model_name)
    base_model = AutoModelForSequenceClassification.from_pretrained(config.base_model_name_or_path, num_labels=1, attn_implementation="flash_attention_2")
    model = PeftModel.from_pretrained(base_model, peft_model_name)
    model = model.merge_and_unload()
    model.eval()
    return model

# Load the tokenizer and model
tokenizer_path = snapshot_download('meta-llama/Llama-2-7b-hf', local_files_only=True)
print(f'Use "{tokenizer_path}" as tokenizer.')
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id


model_path = snapshot_download('castorini/rankllama-v1-7b-lora-passage', local_files_only=True)
print(f'Use "{model_path}" as model.')
model = get_model(model_path)
model.to(device)
model.config.pad_token_id = tokenizer.pad_token_id

print('Ranker is', model)

Will use device: cuda
Use "/root/.cache/huggingface/hub/models--meta-llama--Llama-2-7b-hf/snapshots/01c7f73d771dfac7d292323805ebc428287df4f9" as tokenizer.
Use "/root/.cache/huggingface/hub/models--castorini--rankllama-v1-7b-lora-passage/snapshots/a079eadc8530520a8259e694212fd201531753a9" as model.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Some weights of LlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Llama-2-7b-hf and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Ranker is LlamaForSequenceClassification(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaFlashAttention2(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm()
        (post_attention_layernorm): LlamaRMSNorm()
      )
    )
    (norm): Llama
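The printout shows that the ranker is a `LlamaForSequenceClassification` model, which produces one relevance score per sequence from the logit at the last non-padding token. That is why the cell sets `model.config.pad_token_id`: without a pad token id, transformers cannot locate that token in padded batches and rejects batch sizes larger than one. The pure-Python sketch below mirrors, in simplified form, the index computation used by transformers around v4.41 (newer releases may implement it differently); `last_token_index` and `pool_last_token` are hypothetical names, the token ids are toy values, and right padding is assumed.

```python
from typing import List, Sequence


def last_token_index(input_ids: Sequence[int], pad_token_id: int) -> int:
    # Position of the first pad token; 0 if the row has none (like argmax over all zeros)
    first_pad = next((i for i, tok in enumerate(input_ids) if tok == pad_token_id), 0)
    # Step back one position; the modulo maps -1 to the final position of an unpadded row
    return (first_pad - 1) % len(input_ids)


def pool_last_token(per_token_logits: List[List[float]],
                    batch_input_ids: List[List[int]],
                    pad_token_id: int) -> List[float]:
    # Keep one logit per row: the one at that row's last non-pad position
    return [logits[last_token_index(ids, pad_token_id)]
            for logits, ids in zip(per_token_logits, batch_input_ids)]


PAD = 2  # toy pad id; in the notebook pad_token_id is set to eos_token_id
batch_ids = [[1, 10, 11, PAD, PAD],   # right-padded row
             [1, 12, 13, 14, 15]]     # full-length row
per_token = [[0.1, 0.2, 0.3, 0.4, 0.5],
             [1.0, 1.1, 1.2, 1.3, 1.4]]
print(pool_last_token(per_token, batch_ids, PAD))  # [0.3, 1.4]
```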

In [7]:
batch_size = 16
run = []
for qid in tqdm(qid_to_rerank_data):
    query = qid_to_rerank_data[qid]['query']
    search_results = qid_to_rerank_data[qid]['search_results']
    rerank_unsorted_results = []
    input_texts = [f'query: {query} document: {doc["text"]}' for doc in search_results]
    for i in range(0, len(input_texts), batch_size):
        batch = input_texts[i:i+batch_size]
        inputs = tokenizer(batch, padding=True, truncation=True, return_tensors='pt', max_length=512)
        with torch.autocast(device_type=device, dtype=torch.bfloat16):
            with torch.no_grad():
                outputs = model(**inputs.to(device))
        # squeeze(-1) turns the (batch, 1) logits into a flat list, even for a batch of one
        scores = outputs.logits.squeeze(-1).float().cpu().tolist()
        # offset by the batch start i, so each score is paired with the document it was computed for
        rerank_unsorted_results += [
            {'docid': search_results[i + j]['docid'], 'score': score} for j, score in enumerate(scores)
        ]
    rerank_sorted_results = sorted(rerank_unsorted_results, key=lambda x: x['score'], reverse=True)
    # append one (qid, score, docno) row per document to the run, best-scoring first
    for result in rerank_sorted_results:
        run += [{"qid": qid, "score": result['score'], "docno": result['docid']}]

  0%|                                                                                                                       | 0/5 [00:00<?, ?it/s]We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:10<00:00,  2.00s/it]
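The re-ranking loop scores each query's candidates in batches and then sorts them by score. The detail that matters is bookkeeping: the j-th score of the batch starting at offset `start` belongs to document `start + j`. The sketch below isolates that logic without a model; `rerank` and `toy_scorer` are hypothetical names, and `toy_scorer` is a stand-in for the tokenizer and RankLLaMA call that simply scores longer inputs higher. A batch size of 2 over five documents exercises several batches, including a final partial one.

```python
from typing import Callable, Dict, List


def rerank(query: str,
           search_results: List[Dict],
           score_batch: Callable[[List[str]], List[float]],
           batch_size: int = 16) -> List[Dict]:
    # Same "query: ... document: ..." input template as the notebook
    input_texts = [f'query: {query} document: {doc["text"]}' for doc in search_results]
    scored = []
    for start in range(0, len(input_texts), batch_size):
        scores = score_batch(input_texts[start:start + batch_size])
        # search_results[start + j] is the document that produced scores[j]
        scored += [{'docid': search_results[start + j]['docid'], 'score': s}
                   for j, s in enumerate(scores)]
    # Highest score first; sorted() is stable, so ties keep first-stage order
    return sorted(scored, key=lambda r: r['score'], reverse=True)


def toy_scorer(batch: List[str]) -> List[float]:
    # Hypothetical stand-in for tokenizer + model: longer input, higher score
    return [float(len(text)) for text in batch]


docs = [{'docid': d, 'text': t} for d, t in
        [('d1', 'a'), ('d2', 'bbb'), ('d3', 'cc'), ('d4', 'dddd'), ('d5', 'e')]]
ranking = rerank('q', docs, toy_scorer, batch_size=2)
print([r['docid'] for r in ranking])  # ['d4', 'd2', 'd3', 'd1', 'd5']
```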


### Step 4: Persist the Run File

In [8]:
run = pd.DataFrame(run)

persist_and_normalize_run(run, system_name='rankllama', default_output='.')

The run file is normalized outside the TIRA sandbox, I will store it at ".".
Done. run file is stored under "./run.txt".
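The last cell hands the run (one row per qid, docno, and score) to tira's `persist_and_normalize_run`, which stored it as `./run.txt`. TREC-style run files conventionally have six whitespace-separated columns: query id, the literal `Q0`, document id, rank, score, and a run tag. The sketch below illustrates that layout with a hypothetical `to_trec_lines` helper; it is not tira's implementation, and we have not checked that tira's normalized output matches it byte for byte.

```python
from typing import Dict, List


def to_trec_lines(run_rows: List[Dict], system_name: str) -> List[str]:
    # Group rows by query id, keeping first-seen query order
    by_query: Dict[str, List[Dict]] = {}
    for row in run_rows:
        by_query.setdefault(row['qid'], []).append(row)
    lines = []
    for qid, rows in by_query.items():
        # Ranks start at 1 for the highest-scoring document of each query
        ranked = sorted(rows, key=lambda r: r['score'], reverse=True)
        for rank, row in enumerate(ranked, start=1):
            lines.append(f"{qid} Q0 {row['docno']} {rank} {row['score']} {system_name}")
    return lines


rows = [{'qid': '1', 'docno': 'a', 'score': 0.5},
        {'qid': '1', 'docno': 'b', 'score': 2.5},
        {'qid': '2', 'docno': 'c', 'score': -1.0}]
for line in to_trec_lines(rows, 'rankllama'):
    print(line)
# 1 Q0 b 1 2.5 rankllama
# 1 Q0 a 2 0.5 rankllama
# 2 Q0 c 1 -1.0 rankllama
```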
