In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
from dotenv import load_dotenv
import os
from pymilvus import MilvusClient
from pymilvus.model.hybrid import BGEM3EmbeddingFunction

In [2]:
class Model:
    # _model_id = "meta-llama/Meta-Llama-3-8B"
    _model_id = "nvidia/Llama3-ChatQA-1.5-8B"
    _device = torch.cuda.current_device()

    def __init__(self):
        load_dotenv()
        self._access_token = os.getenv("ACCESS_TOKEN")
        self._tokenizer = self._load_tokenizer()
        self._vector_embeddings = self._load_vector_embeddings()
        self._model = self._load_model()
        self._milvus_client = MilvusClient("milvus_demo.db")
        self._collection_name = "MTG_collection"
        if not self._milvus_client.has_collection(self._collection_name):
            sample_text = "This is a sample text to determine embedding dimension"
            self._milvus_client.create_collection(
                collection_name=self._collection_name,
                dimension=self._vector_embeddings.dim["dense"],
                metric_type="L2",
            )

            print("Inserting Data to Vector database")
            sentence1 = "Malyta is the best card in Modern Horizons 3"
            sentence2 = "Hyidralit is the best card in Modern Horizons 4" 
            sentence3 = "Gafagl is the best card in Pioneer Masters"
            self._milvus_client.insert(
                collection_name=self._collection_name,
                data=[
                    {"id": 0, "text": sentence1, "vector": self.get_embedding(sentence1).tolist()},
                    {"id": 1, "text": sentence2, "vector": self.get_embedding(sentence2).tolist()},
                    {"id": 2, "text": sentence3, "vector": self.get_embedding(sentence3).tolist()},
                ],
            )

    def get_embedding(self, text):
        if text[-1] == ".":
            text = text[:-1]
        return self._vector_embeddings([text])["dense"][0]

    def retrieve_similar_docs(self, query, top_k=5):
        query_embedding = self.get_embedding(query)
        search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
        results = self._milvus_client.search(collection_name=model._collection_name, data=[query_embedding.tolist()], search_params=search_params, limit=5, output_fields=["text"])
        return [entity["entity"]["text"] for entity in results[0]]

    def _load_model(self):
        assert self._tokenizer is not None, "Tokenizer must be initialized first"

        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
        model = AutoModelForCausalLM.from_pretrained(
            self._model_id,
            quantization_config=bnb_config,
            device_map={"": self._device},
            token=self._access_token,
        )
        model.config.pad_token_id = model.config.eos_token_id
        model.generation_config.pad_token_id = self._tokenizer.pad_token_id

        return model

    def _load_tokenizer(self):
        tokenizer = AutoTokenizer.from_pretrained(
            self._model_id,
            token=self._access_token,
        )

        tokenizer.padding_side = "left"

        # Define PAD Token = EOS Token
        tokenizer.pad_token = tokenizer.eos_token
        return tokenizer

    def _load_vector_embeddings(self):
        # Take care of use case with CPU
        # return BGEM3EmbeddingFunction(model_name='BAAI/bge-base-en-v1.5', use_fp16=False, device="cpu", return_sparse=False)
        return BGEM3EmbeddingFunction(model_name='BAAI/bge-base-en-v1.5', use_fp16=True, device="cuda", return_sparse=False)

    def answer(self, query: str) -> str:
        # Pre and post processing taken from: https://towardsdatascience.com/how-to-build-a-local-open-source-llm-chatbot-with-rag-f01f73e2a131
        retrieved_docs = model.retrieve_similar_docs(query)
        context = "\n".join(retrieved_docs)  # Simplification; you might want to process this differently
        prompt = f"""Using the information contained in the context, give a detailed answer to the question.
                    Context: {context}.
                    Question: {query}"""
        chat = [{"role": "user", "content": prompt}]
        formatted_prompt = self._tokenizer.apply_chat_template(
            chat,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = self._tokenizer.encode(formatted_prompt, add_special_tokens=False, return_tensors="pt")
        outputs = self._model.generate(inputs, max_length=200)
        response = self._tokenizer.decode(outputs[0], skip_special_tokens=False)
        response = response[len(formatted_prompt) :]  # remove input prompt from reponse
        response = response.replace("<eos>", "")  # remove eos token
        return response


In [3]:
model = Model()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Fetching 14 files:   0%|          | 0/14 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
model.answer("What is the best card in Modern Horizons 3?")

The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.


'<|begin_of_text|> Malyta is the best card in Modern Horizons 3<|end_of_text|>'

In [5]:
model.answer("What is the best card in Modern Horizons 4?")

'<|begin_of_text|> Hyidralit is the best card in Modern Horizons 4<|end_of_text|>'

In [6]:
model.answer("What is the best card in Pioneer Masters?")

'<|begin_of_text|> Gafagl is the best card in Pioneer Masters<|end_of_text|>'