Proposal for an inference class #58

Closed
JamesDeAntonis opened this issue Jul 19, 2021 · 4 comments
JamesDeAntonis commented Jul 19, 2021

Hi all,

Here is the code for the class I'm currently using to wrap the ColBERT components (apologies if there are errors; I deleted the extra internal code that's only relevant to my team). Maybe something like this could be merged into the actual repo?

Jamie

from dataclasses import dataclass, field
import os

import torch
from transformers.modeling_utils import no_init_weights
from transformers import BertConfig

from colbert.modeling.inference import ModelInference
from colbert.ranking.rankers import Ranker
from colbert.modeling.colbert import ColBERT

@dataclass
class RankerArgs:
    index_path: str = field(metadata={"help": "path to doclens files"})
    faiss_index_path: str = field(metadata={"help": "path to faiss indices (often the same place as `index_path`)"})
    nprobe: int = field(metadata={"help": "the number of clusters to visit during faiss search"})
    part_range: range = field(init=False)

    def __post_init__(self):
        self.part_range = None

class ColbertRetriever:
    def __init__(
        self,
        colbert_model_path: str,
        index_path: str,
        faiss_index_path: str,
        amp: bool = False,
        nprobe: int = 10,
        faiss_depth: int = 1024,
    ):

        inference = ModelInference(
            ColbertModel.from_saved_model(colbert_model_path), amp=amp
        )
        ranker_args = RankerArgs(index_path, faiss_index_path, nprobe)
        self.ranker = Ranker(ranker_args, inference, faiss_depth=faiss_depth)

    def retrieve_and_rerank(self, query: str, k: int):
        Q = self.ranker.encode([query])  # encode the query
        pids, scores = self.ranker.rank(Q)  # rank
        
        assert k <= len(pids), f"requested k={k} but only {len(pids)} results were retrieved"

        pids = pids[:k]
        scores = scores[:k]
        return pids, scores

class ColbertModel(ColBERT):
    @classmethod
    def from_saved_model(cls, model_path: str) -> "ColbertModel":
        """
        load colbert from a saved model

        Parameters
        ----------
        model_path : str
            the full path to a file containing a json with state_dict and other things

        Returns
        -------
        a colbert model
        """

        model_dict = torch.load(model_path)

        config = BertConfig()

        args = model_dict["arguments"]

        with no_init_weights():
            model = cls(
                config,
                query_maxlen=args["query_maxlen"],
                doc_maxlen=args["doc_maxlen"],
                mask_punctuation=args["mask_punctuation"],
                dim=args["dim"],
                similarity_metric=args["similarity"],
            )

        cls._load_state_dict_into_model(
            model, model_dict["model_state_dict"], model_path
        )

        model.eval()
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model.to(device)
        return model
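An aside on the `RankerArgs` pattern above: `field(init=False)` combined with `__post_init__` keeps a derived attribute out of the constructor signature while still guaranteeing it is initialized. A minimal self-contained sketch of that pattern (the `Args` class and its field names here are hypothetical, not part of ColBERT):

```python
from dataclasses import dataclass, field

@dataclass
class Args:
    # Constructor arguments, documented via metadata,
    # mirroring how RankerArgs describes its fields.
    index_path: str = field(metadata={"help": "path to doclens files"})
    nprobe: int = field(default=10, metadata={"help": "clusters to visit"})

    # Excluded from __init__; filled in by __post_init__ instead.
    part_range: range = field(init=False)

    def __post_init__(self):
        # Placeholder value that callers never pass explicitly.
        self.part_range = None

args = Args(index_path="/tmp/index")
print(args.nprobe)      # → 10 (default applies)
print(args.part_range)  # → None (set by __post_init__)
```

Because `part_range` is `init=False`, `Args(index_path=..., part_range=...)` would raise a `TypeError`, which is exactly the behavior `RankerArgs` relies on.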

okhat commented Jul 23, 2021

Many thanks for this, Jamie! Looks pretty cool. Trying to wrap up a v0.3 of ColBERT soon; will make sure to integrate something based on your classes for inference.

@JamesDeAntonis

Sounds good! Let me know if you have any thoughts you want to discuss


okhat commented Jul 27, 2021

@JamesDeAntonis A beta version is out with aggressive compression: https://github.com/stanford-futuredata/ColBERT/tree/binarization

Thought you might be interested.


okhat commented Jul 28, 2021

Your code is pretty good, btw. [Many thanks!] Testing it out tomorrow from Jupyter. Hopefully I can merge it in very soon.

okhat closed this as completed Sep 19, 2021