# Infer-Retrieve-Rerank Llama Pack

<a href="https://colab.research.google.com/github/run-llama/llama-hub/blob/main/llama_hub/llama_packs/research/infer_retrieve_rerank/infer_retrieve_rerank.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This is our implementation of the paper ["In-Context Learning for Extreme Multi-Label Classification](https://arxiv.org/pdf/2401.12178.pdf) by Oosterlinck et al.

The paper proposes "infer-retrieve-rerank", a simple paradigm using frozen LLM/retriever models that can do "extreme"-label classification (the label space is huge).
1. Given a user query, use an LLM to predict an initial set of labels.
2. For each prediction, retrieve the actual label from the corpus.
3. Given the final set of labels, rerank them using an LLM.

All of these can be implemented as LlamaIndex abstractions. In this notebook we show you how to build "infer-retrieve-rerank" from scratch but also how to build it as a LlamaPack.

## Try out a Dataset

We use the BioDEX dataset as mentioned in the paper.

Here is the [link to the paper](https://arxiv.org/pdf/2305.13395.pdf). Here is the [link to the Github repo](https://github.com/KarelDO/BioDEX).

In [None]:
# dataset = datasets.load_dataset("BioDEX/raw_dataset")
import datasets

# load the report-extraction dataset
dataset = datasets.load_dataset("BioDEX/BioDEX-ICSR")

In [48]:
dataset

DatasetDict({
    train: Dataset({
        features: ['title', 'abstract', 'fulltext', 'target', 'pmid', 'fulltext_license', 'title_normalized', 'issue', 'pages', 'journal', 'authors', 'pubdate', 'doi', 'affiliations', 'medline_ta', 'nlm_unique_id', 'issn_linking', 'country', 'mesh_terms', 'publication_types', 'chemical_list', 'keywords', 'references', 'delete', 'pmc', 'other_id', 'safetyreportid', 'fulltext_processed'],
        num_rows: 9624
    })
    validation: Dataset({
        features: ['title', 'abstract', 'fulltext', 'target', 'pmid', 'fulltext_license', 'title_normalized', 'issue', 'pages', 'journal', 'authors', 'pubdate', 'doi', 'affiliations', 'medline_ta', 'nlm_unique_id', 'issn_linking', 'country', 'mesh_terms', 'publication_types', 'chemical_list', 'keywords', 'references', 'delete', 'pmc', 'other_id', 'safetyreportid', 'fulltext_processed'],
        num_rows: 2407
    })
    test: Dataset({
        features: ['title', 'abstract', 'fulltext', 'target', 'pmid', 'fulltext

In [68]:
from llama_index import get_tokenizer
import re
from typing import Set

tokenizer = get_tokenizer()


sample_size = 5

def get_reactions_row(raw_target: str) -> List[str]:
    """Get reactions from a single row."""
    reaction_pattern = re.compile(r"reactions:\s*(.*)")
    reaction_match = reaction_pattern.search(raw_target)
    if reaction_match:
        reactions = reaction_match.group(1).split(",")
        reactions = [r.strip().lower() for r in reactions]
    else:
        reactions = []
    return reactions


def get_reactions_set(dataset) -> Set[str]:
    """Get set of all reactions."""
    reactions = set()
    for data in dataset["train"]:
        reactions.update(set(get_reactions_row(data["target"])))
    return reactions


def get_samples(dataset, sample_size: int = 5):
    """Get processed sample.

    Contains source text and also the reaction label.

    Parse reaction text to specifically extract reactions.
    
    """
    samples = []
    for idx, data in enumerate(dataset["train"]):
        if idx >= sample_size:
            break
        text = data["fulltext_processed"]
        raw_target = data["target"]

        reactions = get_reactions_row(raw_target)

        samples.append({
            "text": text,
            "reactions": reactions
        })
    return samples

### Index each Reaction with a Vector Index

In [69]:
import random

all_reactions = get_reactions_set(dataset)
random.sample(all_reactions, 5)

since Python 3.9 and will be removed in a subsequent version.
  random.sample(all_reactions, 5)


['fixed eruption',
 'abdominal compartment syndrome',
 'foreign body reaction',
 'intermittent claudication',
 'chorioretinal disorder']

In [70]:
from llama_index.schema import TextNode
from llama_index.embeddings import OpenAIEmbedding
from llama_index.ingestion import IngestionPipeline
from llama_index import VectorStoreIndex

reaction_nodes = [TextNode(text=r) for r in all_reactions]
pipeline = IngestionPipeline(transformations=[OpenAIEmbedding()])
reaction_nodes = await pipeline.arun(documents=reaction_nodes)

index = VectorStoreIndex(reaction_nodes)

In [None]:
reaction_nodes[0].embedding

In [76]:
reaction_retriever = index.as_retriever(similarity_top_k=2)

In [77]:
reaction_retriever.retrieve('abdominal')

[NodeWithScore(node=TextNode(id_='e6412990-2c03-40d1-bdb9-586e349b1293', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text='abdominal pain', start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=0.9158768529615556),
 NodeWithScore(node=TextNode(id_='cf067b02-9418-4588-8f98-6d91f91499b4', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={}, text='abdominal symptom', start_char_idx=None, end_char_idx=None, text_template='{metadata_str}\n\n{content}', metadata_template='{key}: {value}', metadata_seperator='\n'), score=0.9104600840524487)]

## Define Infer-Retrieve-Rerank Pipeline

Here we define the core components needed for the infer-retrieve-rerank pipeline.

In [97]:
from llama_index.retrievers import BaseRetriever
from llama_index.llms.llm import LLM
from llama_index.llms import OpenAI
from llama_index.prompts import PromptTemplate
from llama_index.query_pipeline import QueryPipeline
from llama_index.postprocessor.types import BaseNodePostprocessor
from llama_index.postprocessor.rankGPT_rerank import RankGPTRerank
from llama_index.output_parsers import ChainableOutputParser
from typing import List

In [98]:
infer_prompt_str = """\

Your job is to output a list of predictions given context from a given piece of text. The text context,
and information regarding the set of valid predictions is given below. 

Return the predictions as a comma-separated list of strings.

Text Context:
{doc_context}

Prediction Info:
{pred_context}

Predictions: """

infer_prompt = PromptTemplate(infer_prompt_str)

In [99]:
class PredsOutputParser(ChainableOutputParser):
    """Predictions output parser."""

    def parse(self, output: str) -> List[str]:
        """Parse predictions."""
        tokens = output.split(",")
        return [t.strip() for t in tokens]

preds_output_parser = PredsOutputParser()

In [109]:
rerank_str = """\
Given a piece of text, rank the {num} passages above based on their relevance \
to this piece of text. The passages \
should be listed in descending order using identifiers. \
The most relevant passages should be listed first. \
The output format should be [] > [], e.g., [1] > [2]. \
Only response the ranking results, \
do not say any word or explain. \

Here is a given piece of text: {query}. 

"""
rerank_prompt = PromptTemplate(rerank_str)

In [113]:
def infer_retrieve_rerank(
    query: str,
    retriever: BaseRetriever,
    llm: LLM,
    pred_context: str,
    reranker_top_n: int = 3
):
    """Infer retrieve rerank."""
    infer_prompt_c = infer_prompt.as_query_component(
        partial={"pred_context": pred_context}
    )
    infer_pipeline = QueryPipeline(chain=[infer_prompt_c, llm, preds_output_parser])
    preds = infer_pipeline.run(query)

    print(f"PREDS: {preds}")
    all_nodes = []
    for pred in preds:
        nodes = retriever.retrieve(str(pred))
        all_nodes.extend(nodes)

    reranker = RankGPTRerank(
        llm=llm,
        top_n=reranker_top_n,
        rankgpt_rerank_prompt=rerank_prompt,
        # verbose=True,
    )
    reranked_nodes = reranker.postprocess_nodes(all_nodes, query_str=query)
    return [n.get_content() for n in reranked_nodes]

    

### Run Over Sample Data

In [114]:
samples = get_samples(dataset, sample_size=5)

In [116]:
reaction_retriever = index.as_retriever(similarity_top_k=2)
llm = OpenAI(model="gpt-3.5-turbo-16k")
pred_context = """\
The output predictins should be a list of comma-separated adverse \
drug reactions. \
"""

reranker_top_n = 10

pred_reactions = []
gt_reactions = []
for idx, sample in enumerate(samples):
    print(idx)
    cur_pred_reactions = infer_retrieve_rerank(
        sample["text"],
        reaction_retriever,
        llm,
        pred_context,
        reranker_top_n=reranker_top_n
    )
    cur_gt_reactions = sample["reactions"]

    pred_reactions.append(cur_pred_reactions)
    gt_reactions.append(cur_gt_reactions)
    break

0
PREDS: ['fluid overload', 'acute respiratory distress syndrome', 'anxiety', 'delirium', 'myocardial insufficiency', 'hypervolemia', 'hypovolemia', 'pneumonia', 'pleural effusion', 'heart failure', 'respiratory distress', 'allergic reaction', 'eosinophilia', 'diarrhea', 'rash']
After Reranking, new rank list for nodes: [0, 2, 20, 18, 9, 10, 8, 16, 14, 26, 27, 24, 25, 29, 28, 6, 7, 4, 5, 12, 11, 13, 1, 3, 21, 22, 23, 15, 17, 19]

In [117]:
pred_reactions[0]

['fluid overload',
 'acute respiratory distress syndrome',
 'respiratory distress',
 'cardiac failure',
 'myocardial infarction',
 'hypervolaemia',
 'cardiovascular insufficiency',
 'pleural effusion',
 'pneumonia',
 'diarrhoea']

In [118]:
gt_reactions[0]

['diarrhoea', 'drug hypersensitivity', 'rash']