diff --git a/docs/api/retrieval_model_clients/ColBERTv2.md b/docs/api/retrieval_model_clients/ColBERTv2.md
index 2dd31bef8c..a8fea94928 100644
--- a/docs/api/retrieval_model_clients/ColBERTv2.md
+++ b/docs/api/retrieval_model_clients/ColBERTv2.md
@@ -49,3 +49,81 @@ retrieval_response = colbertv2_wiki17_abstracts('When was the first FIFA World C
 for result in retrieval_response:
     print("Text:", result['text'], "\n")
 ```
+
+# dspy.ColBERTv2RetrieverLocal
+
+This is adapted from the official documentation of [ColBERTv2](https://github.com/stanford-futuredata/ColBERT/tree/main), following the [paper](https://arxiv.org/abs/2112.01488).
+
+You can install ColBERTv2 by following the installation instructions [here](https://github.com/stanford-futuredata/ColBERT?tab=readme-ov-file#installation).
+
+### Constructor
+The constructor initializes ColBERTv2 as a local retriever object. You can also serve your local ColBERTv2 index using the server snippet from [here](https://github.com/stanford-futuredata/ColBERT/blob/main/server.py).
+
+```python
+class ColBERTv2RetrieverLocal:
+    def __init__(
+        self,
+        passages: List[str],
+        colbert_config=None,
+        load_only: bool = False):
+```
+
+**Parameters**
+- `passages` (_List[str]_): List of passages to be indexed.
+- `colbert_config` (_ColBERTConfig_, _Optional_): ColBERT config used for building and searching the index. Defaults to None.
+- `load_only` (_bool_): Whether to only load an existing index rather than build and then load one. Defaults to False.
+
+The `colbert_config` object is required for ColBERTv2; import it with `from colbert.infra.config import ColBERTConfig`. The available config attributes are described [here](https://github.com/stanford-futuredata/ColBERT/blob/main/colbert/infra/config/settings.py).
+
+### Methods
+
+#### `forward(self, query: str, k: int, **kwargs) -> Union[list[str], list[dotdict]]`
+
+Retrieves the passages from the index that are most relevant to the query. If you already have a local index, pass `load_only=True` and set the `index` attribute of the ColBERTConfig to the local index path. Also make sure the `checkpoint` attribute of the ColBERTConfig points to the embedding model that was used to build the index.
+
+**Parameters:**
+- `query` (_str_): Query string used for retrieval.
+- `k` (_int_, _optional_): Number of passages to retrieve. Defaults to 7.
+
+It returns a `Prediction` object for each query:
+
+```python
+Prediction(
+    pid=[33, 6, 47, 74, 48],
+    passages=['No pain, no gain.', 'The best things in life are free.', 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Patience is a virtue.']
+)
+```
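+
+Below is a minimal usage sketch that builds an index over a small passage list and queries it through `dspy.Retrieve`; it mirrors the accompanying notebook (`examples/integrations/colbert/colbert_local.ipynb`), and the `index_name`, `experiment`, and `checkpoint` values are placeholders you should adapt to your setup:
+
+```python
+import dspy
+from colbert.infra.config import ColBERTConfig
+
+passages = ["No pain, no gain.", "Practice makes perfect.", "Out of sight, out of mind."]
+
+colbert_config = ColBERTConfig()
+colbert_config.index_name = "Colbert-RM"              # placeholder index name
+colbert_config.experiment = "Colbert-Experiment"      # placeholder experiment folder
+colbert_config.checkpoint = "colbert-ir/colbertv2.0"  # embedding model used to build the index
+
+# Build the index (set load_only=True to reuse an already built index)
+colbert_retriever = dspy.ColBERTv2RetrieverLocal(
+    passages=passages,
+    colbert_config=colbert_config,
+    load_only=False,
+)
+
+# Register the retriever and query it
+dspy.settings.configure(rm=colbert_retriever)
+retrieve = dspy.Retrieve(k=3)
+prediction = retrieve("What is the meaning of life?")
+print(prediction.passages)
+```
+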
+# dspy.ColBERTv2RerankerLocal
+
+You can also use ColBERTv2 as a reranker in DSPy.
+
+### Constructor
+
+```python
+class ColBERTv2RerankerLocal:
+
+    def __init__(
+        self,
+        colbert_config=None,
+        checkpoint: str = 'bert-base-uncased'):
+```
+
+**Parameters**
+- `colbert_config` (_ColBERTConfig_, _Optional_): ColBERT config used for building and searching. Defaults to None.
+- `checkpoint` (_str_): Embedding model used to embed both the documents and the query.
+
+### Methods
+#### `forward(self, query: str, passages: List[str])`
+
+Given a query and a list of passages, it reranks the passages and returns their similarity scores, with the passages ordered from most to least similar.
+
+**Parameters:**
+- `query` (_str_): Query string used for reranking.
+- `passages` (_List[str]_): List of passages to be reranked.
+
+It returns an array of similarity scores; you can map the scores back to the passages as follows:
+
+```python
+import numpy as np
+
+for idx in np.argsort(scores_arr)[::-1]:
+    print(f"Passage = {passages[idx]} --> Score = {scores_arr[idx]}")
+```
\ No newline at end of file
diff --git a/dsp/modules/__init__.py b/dsp/modules/__init__.py
index a31c72911f..48b520e827 100644
--- a/dsp/modules/__init__.py
+++ b/dsp/modules/__init__.py
@@ -9,7 +9,7 @@
 from .clarifai import *
 from .cloudflare import *
 from .cohere import *
-from .colbertv2 import ColBERTv2
+from .colbertv2 import ColBERTv2, ColBERTv2RerankerLocal, ColBERTv2RetrieverLocal
 from .databricks import *
 from .dummy_lm import *
 from .google import *
diff --git a/dsp/modules/colbertv2.py b/dsp/modules/colbertv2.py
index 8ff3c16225..67b246c5e5 100644
--- a/dsp/modules/colbertv2.py
+++ b/dsp/modules/colbertv2.py
@@ -1,5 +1,5 @@
 import functools
-from typing import Any, Optional, Union
+from typing import Any, List, Optional, Union
 
 import requests
 
@@ -74,3 +74,120 @@ def colbertv2_post_request_v2_wrapped(*args, **kwargs):
 
 
 colbertv2_post_request = colbertv2_post_request_v2_wrapped
+
+class ColBERTv2RetrieverLocal:
+    def __init__(self,passages:List[str],colbert_config=None,load_only:bool=False):
+        """Colbertv2 retriever module
+
+        Args:
+            passages (List[str]): list of passages
+            colbert_config (ColBERTConfig, optional): colbert config for building and searching. Defaults to None.
+            load_only (bool, optional): whether to load the index or build and then load. Defaults to False.
+        """
+        assert colbert_config is not None, "Please pass a valid colbert_config, which you can import from colbert.infra.config import ColBERTConfig and modify it"
+        self.colbert_config = colbert_config
+
+        assert self.colbert_config.checkpoint is not None, "Please pass a valid checkpoint like colbert-ir/colbertv2.0, which you can modify in the ColBERTConfig with attribute name checkpoint"
+        self.passages = passages
+
+        assert self.colbert_config.index_name is not None, "Please pass a valid index_name, which you can modify in the ColBERTConfig with attribute name index_name"
+        self.passages = passages
+
+        if not load_only:
+            print(f"Building the index for experiment {self.colbert_config.experiment} with index name {self.colbert_config.index_name}")
+            self.build_index()
+
+        print(f"Loading the index for experiment {self.colbert_config.experiment} with index name {self.colbert_config.index_name}")
+        self.searcher = self.get_index()
+
+    def build_index(self):
+
+        try:
+            import colbert
+        except ImportError:
+            print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].")
+
+        from colbert import Indexer
+        from colbert.infra import Run, RunConfig
+        with Run().context(RunConfig(nranks=self.colbert_config.nranks, experiment=self.colbert_config.experiment)):
+            indexer = Indexer(checkpoint=self.colbert_config.checkpoint, config=self.colbert_config)
+            indexer.index(name=self.colbert_config.index_name, collection=self.passages, overwrite=True)
+
+    def get_index(self):
+        try:
+            import colbert
+        except ImportError:
+            print("Colbert not found. 
Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].") + + from colbert import Searcher + from colbert.infra import Run, RunConfig + + with Run().context(RunConfig(experiment=self.colbert_config.experiment)): + searcher = Searcher(index=self.colbert_config.index_name, collection=self.passages) + return searcher + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.forward(*args, **kwargs) + + def forward(self,query:str,k:int=7,**kwargs): + import torch + + if kwargs.get("filtered_pids"): + filtered_pids = kwargs.get("filtered_pids") + assert type(filtered_pids) == List[int], "The filtered pids should be a list of integers" + device = "cuda" if torch.cuda.is_available() else "cpu" + results = self.searcher.search( + query, + #Number of passages to receive + k=k, + #Passing the filter function of relevant + filter_fn=lambda pids: torch.tensor( + [pid for pid in pids if pid in filtered_pids],dtype=torch.int32).to(device)) + else: + searcher_results = self.searcher.search(query, k=k) + results = [] + for pid,rank,score in zip(*searcher_results): + results.append(dotdict({'long_text':self.searcher.collection[pid],'score':score,'pid':pid})) + return results + +class ColBERTv2RerankerLocal: + + def __init__(self,colbert_config=None,checkpoint:str='bert-base-uncased'): + try: + import colbert + except ImportError: + print("Colbert not found. Please check your installation or install the module using pip install colbert-ai[faiss-gpu,torch].") + """_summary_ + + Args: + colbert_config (ColBERTConfig, optional): Colbert config. Defaults to None. + checkpoint_name (str, optional): checkpoint for embeddings. Defaults to 'bert-base-uncased'. + """ + self.colbert_config = colbert_config + self.checkpoint = checkpoint + self.colbert_config.checkpoint = checkpoint + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + return self.forward(*args, **kwargs) + + def forward(self,query:str,passages:List[str]=[]): + assert len(passages) > 0, "Passages should not be empty" + + import numpy as np + from colbert.modeling.colbert import ColBERT + from colbert.modeling.tokenization.doc_tokenization import DocTokenizer + from colbert.modeling.tokenization.query_tokenization import QueryTokenizer + + self.colbert_config.nway = len(passages) + query_tokenizer = QueryTokenizer(self.colbert_config,verbose=1) + doc_tokenizer = DocTokenizer(self.colbert_config) + query_ids,query_masks = query_tokenizer.tensorize([query]) + doc_ids,doc_masks = doc_tokenizer.tensorize(passages) + + col = ColBERT(self.checkpoint,self.colbert_config) + Q = col.query(query_ids,query_masks) + DOC_IDS,DOC_MASKS = col.doc(doc_ids,doc_masks,keep_dims='return_mask') + Q_duplicated = Q.repeat_interleave(len(passages), dim=0).contiguous() + tensor_scores = col.score(Q_duplicated,DOC_IDS,DOC_MASKS) + passage_score_arr = np.array([score.cpu().detach().numpy().tolist() for score in tensor_scores]) + return passage_score_arr \ No newline at end of file diff --git a/dsp/primitives/search.py b/dsp/primitives/search.py index bfabed4955..1ad9a07cde 100644 --- a/dsp/primitives/search.py +++ b/dsp/primitives/search.py @@ -1,9 +1,11 @@ +import logging from collections.abc import Iterable import numpy as np import dsp +logger = logging.getLogger(__name__) def retrieve(query: str, k: int, **kwargs) -> list[str]: """Retrieves passages from the RM for the query and returns the top k passages.""" @@ -15,12 +17,25 @@ def retrieve(query: str, k: int, **kwargs) -> list[str]: # TODO: we should unify the 
type signatures of dspy.Retriever passages = [passages] passages = [psg.long_text for psg in passages] - + if dsp.settings.reranker: passages_cs_scores = dsp.settings.reranker(query, passages) passages_cs_scores_sorted = np.argsort(passages_cs_scores)[::-1] passages = [passages[idx] for idx in passages_cs_scores_sorted] + + return passages +def retrievewithMetadata(query: str, k: int, **kwargs) -> list[str]: + """Retrieves passages from the RM for the query and returns the top k passages.""" + + if not dsp.settings.rm: + raise AssertionError("No RM is loaded.") + passages = dsp.settings.rm(query, k=k, **kwargs) + if not isinstance(passages, Iterable): + # it's not an iterable yet; make it one. + # TODO: we should unify the type signatures of dspy.Retriever + passages = [passages] + return passages @@ -38,9 +53,31 @@ def retrieveRerankEnsemble(queries: list[str], k: int,**kwargs) -> list[str]: passages_cs_scores[idx], ] + passages = [(np.average(score), text) for text, score in passages.items()] return [text for _, text in sorted(passages, reverse=True)[:k]] +def retrieveRerankEnsemblewithMetadata(queries: list[str], k: int, **kwargs) -> list[str]: + if not (dsp.settings.rm and dsp.settings.reranker): + raise AssertionError("Both RM and Reranker are needed to retrieve & re-rank.") + queries = [q for q in queries if q] + all_queries_passages = [] + for query in queries: + passages = [] + retrieved_passages = dsp.settings.rm(query, k=k * 3, **kwargs) + passages_cs_scores = dsp.settings.reranker( + query, passages=[psg["long_text"] for psg in retrieved_passages], + ) + for idx in np.argsort(passages_cs_scores)[::-1][:k]: + curr_passage = retrieved_passages[idx] + curr_passage["rerank_score"] = passages_cs_scores[idx] + passages.append(curr_passage) + all_queries_passages.append(passages) + if len(queries) == 1: + return all_queries_passages[0] + else: + return all_queries_passages + def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) -> list[str]: """Retrieves passages from the RM for each query in queries and returns the top k passages @@ -50,7 +87,6 @@ def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) raise AssertionError("No RM is loaded.") if dsp.settings.reranker: return retrieveRerankEnsemble(queries, k, **kwargs) - queries = [q for q in queries if q] if len(queries) == 1: @@ -68,4 +104,43 @@ def retrieveEnsemble(queries: list[str], k: int, by_prob: bool = True,**kwargs) passages = sorted(passages, reverse=True)[:k] passages = [text for _, text in passages] + return passages + +def retrieveEnsemblewithMetadata( + queries: list[str], k: int, by_prob: bool = True, **kwargs, +) -> list[str]: + """Retrieves passages from the RM for each query in queries and returns the top k passages + based on the probability or score. 
+ """ + + if not dsp.settings.rm: + raise AssertionError("No RM is loaded.") + if not dsp.settings.reranker: + return retrieveRerankEnsemblewithMetadata(queries=queries,k=k) + + queries = [q for q in queries if q] + + if len(queries) == 1: + return retrieve(queries[0], k) + all_queries_passages = [] + for q in queries: + passages = {} + retrieved_passages = dsp.settings.rm(q, k=k * 3, **kwargs) + for idx, psg in enumerate(retrieved_passages): + if by_prob: + passages[(idx, psg.long_text)] = ( + passages.get(psg.long_text, 0.0) + psg.prob + ) + else: + passages[(idx, psg.long_text)] = ( + passages.get(psg.long_text, 0.0) + psg.score + ) + retrieved_passages[idx]["tracking_idx"] = idx + passages = sorted(passages.items(), key=lambda item: item[1])[:k] + req_indices = [psg[0][0] for psg in passages] + passages = [ + rp for rp in retrieved_passages if rp.get("tracking_idx") in req_indices + ] + all_queries_passages.append(passages) + return all_queries_passages \ No newline at end of file diff --git a/dspy/__init__.py b/dspy/__init__.py index 2c6c1d7d45..9602b1f60b 100644 --- a/dspy/__init__.py +++ b/dspy/__init__.py @@ -20,6 +20,8 @@ Databricks = dsp.Databricks Cohere = dsp.Cohere ColBERTv2 = dsp.ColBERTv2 +ColBERTv2RerankerLocal = dsp.ColBERTv2RerankerLocal +ColBERTv2RetrieverLocal = dsp.ColBERTv2RetrieverLocal Pyserini = dsp.PyseriniRetriever Clarifai = dsp.ClarifaiLLM CloudflareAI = dsp.CloudflareAI diff --git a/dspy/retrieve/__init__.py b/dspy/retrieve/__init__.py index 1d1f9e8b7d..2f699c23ad 100644 --- a/dspy/retrieve/__init__.py +++ b/dspy/retrieve/__init__.py @@ -1 +1 @@ -from .retrieve import Retrieve \ No newline at end of file +from .retrieve import Retrieve, RetrieveThenRerank \ No newline at end of file diff --git a/dspy/retrieve/retrieve.py b/dspy/retrieve/retrieve.py index 7f026e2aa5..6c50e2bbf5 100644 --- a/dspy/retrieve/retrieve.py +++ b/dspy/retrieve/retrieve.py @@ -1,11 +1,21 @@ import random -from typing import List, Optional, Union +from typing import Dict, List, Optional, Union import dsp from dspy.predict.parameter import Parameter from dspy.primitives.prediction import Prediction +def single_query_passage(passages): + passages_dict = {key: [] for key in list(passages[0].keys())} + for docs in passages: + for key, value in docs.items(): + passages_dict[key].append(value) + if "long_text" in passages_dict: + passages_dict["passages"] = passages_dict.pop("long_text") + return Prediction(**passages_dict) + + class Retrieve(Parameter): name = "Search" input_variable = "query" @@ -29,14 +39,120 @@ def load_state(self, state): def __call__(self, *args, **kwargs): return self.forward(*args, **kwargs) - def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None,**kwargs) -> Prediction: - queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries - queries = [query.strip().split('\n')[0].strip() for query in queries] + def forward( + self, + query_or_queries: Union[str, List[str]], + k: Optional[int] = None, + by_prob: bool = True, + with_metadata: bool = False, + **kwargs, + ) -> Union[List[str], Prediction, List[Prediction]]: + # queries = [query_or_queries] if isinstance(query_or_queries, str) else query_or_queries + # queries = [query.strip().split('\n')[0].strip() for query in queries] + + # # print(queries) + # # TODO: Consider removing any quote-like markers that surround the query too. 
+ # k = k if k is not None else self.k + # passages = dsp.retrieveEnsemble(queries, k=k,**kwargs) + # return Prediction(passages=passages) + queries = ( + [query_or_queries] + if isinstance(query_or_queries, str) + else query_or_queries + ) + queries = [query.strip().split("\n")[0].strip() for query in queries] # print(queries) # TODO: Consider removing any quote-like markers that surround the query too. k = k if k is not None else self.k - passages = dsp.retrieveEnsemble(queries, k=k,**kwargs) - return Prediction(passages=passages) + if not with_metadata: + passages = dsp.retrieveEnsemble(queries, k=k, by_prob=by_prob, **kwargs) + return Prediction(passages=passages) + else: + passages = dsp.retrieveEnsemblewithMetadata( + queries, k=k, by_prob=by_prob, **kwargs + ) + if isinstance(passages[0], List): + pred_returns = [] + for query_passages in passages: + passages_dict = { + key: [] + for key in list(query_passages[0].keys()) + if key != "tracking_idx" + } + for psg in query_passages: + for key, value in psg.items(): + if key == "tracking_idx": + continue + passages_dict[key].append(value) + if "long_text" in passages_dict: + passages_dict["passages"] = passages_dict.pop("long_text") + pred_returns.append(Prediction(**passages_dict)) + return pred_returns + elif isinstance(passages[0], Dict): + # passages dict will contain {"long_text":long_text_list,"metadatas";metadatas_list...} + return single_query_passage(passages=passages) + # TODO: Consider doing Prediction.from_completions with the individual sets of passages (per query) too. + + +class RetrieveThenRerank(Parameter): + name = "Search" + input_variable = "query" + desc = "takes a search query and returns one or more potentially relevant passages followed by reranking from a corpus" + + def __init__(self, k=3): + self.stage = random.randbytes(8).hex() + self.k = k + + def reset(self): + pass + + def dump_state(self): + state_keys = ["k"] + return {k: getattr(self, k) for k in state_keys} + + def load_state(self, state): + for name, value in state.items(): + setattr(self, name, value) + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + def forward( + self, + query_or_queries: Union[str, List[str]], + k: Optional[int] = None, + with_metadata: bool = False, + **kwargs, + ) -> Union[List[str], Prediction, List[Prediction]]: + queries = ( + [query_or_queries] + if isinstance(query_or_queries, str) + else query_or_queries + ) + queries = [query.strip().split("\n")[0].strip() for query in queries] + + # print(queries) + # TODO: Consider removing any quote-like markers that surround the query too. 
+ k = k if k is not None else self.k + if not with_metadata: + passages = dsp.retrieveRerankEnsemble(queries, k=k, **kwargs) + return passages + else: + passages = dsp.retrieveRerankEnsemblewithMetadata(queries, k=k, **kwargs) + if isinstance(passages[0], List): + pred_returns = [] + for query_passages in passages: + passages_dict = {key: [] for key in list(query_passages[0].keys())} + for docs in query_passages: + for key, value in docs.items(): + passages_dict[key].append(value) + if "long_text" in passages_dict: + passages_dict["passages"] = passages_dict.pop("long_text") + + pred_returns.append(Prediction(**passages_dict)) + return pred_returns + elif isinstance(passages[0], Dict): + return single_query_passage(passages=passages) diff --git a/examples/integrations/colbert/colbert_local.ipynb b/examples/integrations/colbert/colbert_local.ipynb new file mode 100644 index 0000000000..f5eb881a23 --- /dev/null +++ b/examples/integrations/colbert/colbert_local.ipynb @@ -0,0 +1,404 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## IN THIS NOTEBOOK, WE WILL EXPLORE THE COLBERT AS A RERANKER AND RETRIEVER IN LOCAL MODE. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "* If you want to build a server from your colbert local index, please refer [here](https://github.com/stanford-futuredata/ColBERT/blob/main/server.py)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from colbert.infra.config import ColBERTConfig" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "# You can set this environment variable for debugging purposes\n", + "os.environ['COLBERT_LOAD_TORCH_EXTENSION_VERBOSE'] = \"True\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Let's review the colbert config class" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# You can view the different attributes of the colbert config by uncommenting cell below\n", + "# for k,v in ColBERTConfig().__dict__.items():\n", + "# print(f\"{k} --> {v}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "passages = [\"It's a piece of cake.\", \"Don't put off until tomorrow what you can do today.\", 'To kill two birds with one stone.', 'Actions speak louder than words.', 'Honesty is the best policy.', 'If you want something done right, do it yourself.', 'The best things in life are free.', \"Don't count your chickens before they hatch.\", 'She sells seashells by the seashore.', 'Practice makes perfect.', \"Where there's a will, there's a way.\", 'Absence makes the heart grow fonder.', 'When the going gets tough, the tough get going.', 'A journey of a thousand miles begins with a single step.', \"You can't have your cake and eat it too.\", \"If you can't beat them, join them.\", 'Keep your friends close and your enemies closer.', \"Don't put all your eggs in one basket.\", \"All's fair in love and war.\", 'Every dog has its day.', 'All good things must come to an end.', 'Once bitten, twice shy.', \"The apple doesn't fall far from the tree.\", 'A penny saved is a penny earned.', \"Don't bite the hand that feeds you.\", 'You reap what you sow.', 'An apple a day keeps the doctor away.', \"One man's trash is another man's treasure.\", 'The squeaky wheel gets the grease.', 'A picture is worth a 
thousand words.', 'Fortune favors the bold.', 'Practice what you preach.', 'A watched pot never boils.', 'No pain, no gain.', \"You can't make an omelet without breaking eggs.\", \"There's no place like home.\", 'Ask and you shall receive.', 'Let sleeping dogs lie.', 'If the shoe fits, wear it.', 'Every cloud has a silver lining.', 'Look before you leap.', 'The more, the merrier.', 'The grass is always greener on the other side.', 'Beauty is only skin deep.', \"Two wrongs don't make a right.\", 'Beauty is in the eye of the beholder.', 'Necessity is the mother of invention.', 'Out of sight, out of mind.', 'Patience is a virtue.', 'Curiosity killed the cat.', \"If at first you don't succeed, try, try again.\", \"Beggars can't be choosers.\", 'Too many cooks spoil the broth.', 'Easy come, easy go.', \"Don't cry over spilled milk.\", \"There's no such thing as a free lunch.\", 'A bird in the hand is worth two in the bush.', 'Good things come to those who wait.', 'The quick brown fox jumps over the lazy dog.', 'It takes two to tango.', 'A friend in need is a friend indeed.', 'Like father, like son.', 'Let bygones be bygones.', 'Kill two birds with one stone.', 'A penny for your thoughts.', 'I am the master of my fate, I am the captain of my soul.', 'The pen is mightier than the sword.', 'When in Rome, do as the Romans do.', \"Rome wasn't built in a day.\", \"You can't judge a book by its cover.\", \"It's raining cats and dogs.\", 'Make hay while the sun shines.', \"It's better to be safe than sorry.\", 'The early bird catches the worm.', 'To be or not to be, that is the question.', 'Better late than never.']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## This tutorial is running from the `examples/integrations/tutorials folder`, hence we need to add the system path for dspy\n", + "\n", + "* If you have installed the dspy package, then you don't need to run the below cell" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "sys.path.append(\"../../..\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## COLBERT AS RETRIEVER" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import dspy\n", + "colbert_config = ColBERTConfig()\n", + "colbert_config.index_name = \"Colbert-RM\"\n", + "colbert_config.experiment = \"Colbert-Experiment\"\n", + "colbert_config.checkpoint = \"colbert-ir/colbertv2.0\"\n", + "colbert_retriever = dspy.ColBERTv2RetrieverLocal(\n", + " passages = passages,load_only=False,\n", + " colbert_config=colbert_config\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "#CONFIGURE COLBERT IN DSPY\n", + "dspy.settings.configure(rm=colbert_retriever)\n", + "\n", + "retrieved_docs = dspy.Retrieve(k=5)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "DeprecationWarning: 'dspy.Retrieve' for reranking has been deprecated, please use dspy.RetrieveThenRerank. The reranking is ignored here. In the future this will raise an error.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "#> QueryTokenizer.tensorize(batch_text[0], batch_background[0], bsize) ==\n", + "#> Input: . 
What is the meaning of life?, \t\t True, \t\t None\n", + "#> Output IDs: torch.Size([32]), tensor([ 101, 1, 2054, 2003, 1996, 3574, 1997, 2166, 1029, 102, 103, 103,\n", + " 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103, 103,\n", + " 103, 103, 103, 103, 103, 103, 103, 103], device='cuda:0')\n", + "#> Output Mask: torch.Size([32]), tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0')\n", + "\n" + ] + } + ], + "source": [ + "pred = retrieved_docs(\n", + " \"What is the meaning of life?\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Prediction(\n", + " score=[nan, nan, nan, nan, nan],\n", + " pid=[33, 6, 47, 74, 48],\n", + " passages=['No pain, no gain.', 'The best things in life are free.', 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Patience is a virtue.']\n", + ")" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pred" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "DeprecationWarning: 'dspy.Retrieve' for reranking has been deprecated, please use dspy.RetrieveThenRerank. The reranking is ignored here. In the future this will raise an error.\n" + ] + } + ], + "source": [ + "multiple_pred = retrieved_docs(\n", + " [\"What is the meaning of life?\",\"Meaning of pain?\"],by_prob=False\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Prediction(\n", + " score=[nan, nan, nan, nan, nan],\n", + " pid=[33, 6, 47, 74, 48],\n", + " passages=['No pain, no gain.', 'The best things in life are free.', 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Patience is a virtue.']\n", + " ),\n", + " Prediction(\n", + " score=[nan, nan, nan, nan, nan],\n", + " pid=[16, 0, 47, 74, 26],\n", + " passages=['Keep your friends close and your enemies closer.', \"It's a piece of cake.\", 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'An apple a day keeps the doctor away.']\n", + " )]" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "multiple_pred" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## COLBERT AS RERANKER" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "colbert_config = ColBERTConfig()\n", + "colbert_config.index_name = 'colbert-ir-index'\n", + "colbert_reranker = dspy.ColBERTv2RerankerLocal(\n", + " checkpoint='colbert-ir/colbertv2.0',colbert_config=colbert_config)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "dspy.settings.configure(rm=colbert_retriever,reranker=colbert_reranker)\n", + "\n", + "retrieve_rerank = dspy.RetrieveThenRerank(k=5)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "pred = retrieve_rerank(\n", + " [\"What is the meaning of life?\",\"Meaning of pain?\"]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Prediction(\n", + " score=[nan, nan, nan, nan, nan],\n", + " pid=[6, 48, 74, 47, 
33],\n", + " rerank_score=[15.8359375, 14.2109375, 12.5703125, 11.7890625, 9.1796875],\n", + " passages=['The best things in life are free.', 'Patience is a virtue.', 'To be or not to be, that is the question.', 'Out of sight, out of mind.', 'No pain, no gain.']\n", + " ),\n", + " Prediction(\n", + " score=[nan, nan, nan, nan, nan],\n", + " pid=[33, 0, 47, 74, 16],\n", + " rerank_score=[19.828125, 12.2890625, 11.171875, 9.09375, 6.8984375],\n", + " passages=['No pain, no gain.', \"It's a piece of cake.\", 'Out of sight, out of mind.', 'To be or not to be, that is the question.', 'Keep your friends close and your enemies closer.']\n", + " )]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pred" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## YOU CAN ALSO COLBERT RERANKER AS STANDALONE MODEL" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install tabulate" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import tabulate\n", + "\n", + "scores_arr = colbert_reranker(\n", + " \"What is the meaning of life and pain?\",\n", + " # Pass a subset of passages\n", + " passages[:10]\n", + ")\n", + "\n", + "tabulate_data = []\n", + "for idx in np.argsort(scores_arr)[::-1]:\n", + " # print(f\"Passage = {passages[idx]} --> Score = {scores_arr[idx]}\")\n", + " tabulate_data.append([passages[idx],scores_arr[idx]])\n", + "\n", + "table = tabulate.tabulate(tabulate_data,tablefmt=\"html\",headers={'sentence','score'})" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
sentence                                               score\n",
       "The best things in life are free.                    12.5156\n",
       "It's a piece of cake.                                 10\n",
       "Practice makes perfect.                               8.27344\n",
       "Honesty is the best policy.                           7.57422\n",
       "To kill two birds with one stone.                     7.51953\n",
       "Actions speak louder than words.                      7.05469\n",
       "If you want something done right, do it yourself.     6.52344\n",
       "Don't put off until tomorrow what you can do today.   3.78711\n",
       "She sells seashells by the seashore.                  2.77148\n",
       "Don't count your chickens before they hatch.          1.82227
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from IPython.display import HTML, display\n", + "display(HTML(table))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}