# Using MLC-LLM for RAG in Python

### Load the reference model from Hugging Face

In [1]:
import os

from transformers import AutoModel, AutoTokenizer

# Local copy of the snowflake-arctic-embed-m checkpoint, used as the reference
# implementation. Set SNOWFLAKE_MODEL_PATH to point elsewhere instead of editing
# the hard-coded developer path below.
snowflake_model_path = os.environ.get(
    "SNOWFLAKE_MODEL_PATH", "/Users/cfruan/Documents/models/snowflake-arctic-embed-m"
)

tokenizer = AutoTokenizer.from_pretrained(snowflake_model_path)
# add_pooling_layer=False: this notebook only uses the per-token hidden states
# (element [0] of the model output), never the pooler output.
model = AutoModel.from_pretrained(snowflake_model_path, add_pooling_layer=False)
# Inference mode (disables the Dropout layers shown in the printed architecture).
model.eval()

  from .autonotebook import tqdm as notebook_tqdm


BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

### Compare input tokenization

In [2]:
# Queries get the model's retrieval instruction prefix; documents are used as-is.
query_prefix = "Represent this sentence for searching relevant passages: "
queries = ["what is snowflake?", "Where can I get the best tacos?"]
queries_with_prefix = [query_prefix + query for query in queries]

documents = ["The Data Cloud!", "Mexico City of Course!"]

# Pad each batch to its longest sequence, truncate at 512 tokens (the size of the
# position-embedding table), and return PyTorch tensors.
query_tokens = tokenizer(
    queries_with_prefix,
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
)
document_tokens = tokenizer(
    documents,
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
)

In [26]:
print(f"queries_with_prefix[0]:  {queries_with_prefix[0]}")
print(f"documents[0]:  {documents[0]}")
print(f"query_tokens[input_ids].shape:  {query_tokens['input_ids'].shape}")
print(f"document_tokens[input_ids].shape:  {document_tokens['input_ids'].shape}")

# The first query is shorter than the second, so padding=True fills its last
# 3 positions with id 0 (the padding id) to match the batch's longest sequence.
print(f"query_tokens[input_ids][0]:  {query_tokens['input_ids'][0]}")

queries_with_prefix[0]:  Represent this sentence for searching relevant passages: what is snowflake?
documents[0]:  The Data Cloud!
query_tokens[input_ids].shape:  torch.Size([2, 19])
document_tokens[input_ids].shape:  torch.Size([2, 7])
query_tokens[input_ids][0]:  tensor([  101,  5050,  2023,  6251,  2005,  6575,  7882, 13768,  1024,  2054,
         2003,  4586, 10258, 13808,  1029,   102,     0,     0,     0])


In [15]:
import os

from mlc_llm.embeddings.embeddings import MLCEmbeddings

# Compiled MLC artifacts for snowflake-arctic-embed-m: the model folder and the
# compiled Metal library. Set MLC_EMBED_MODEL_DIR / MLC_EMBED_MODEL_LIB to point
# elsewhere instead of editing the developer paths below.
mlc_embeddings = MLCEmbeddings(
    os.environ.get(
        "MLC_EMBED_MODEL_DIR",
        "/Users/cfruan/Documents/mlc-llm/dist/snowflake-arctic-embed-m-q0f32-MLC",
    ),
    os.environ.get(
        "MLC_EMBED_MODEL_LIB",
        "/Users/cfruan/Documents/mlc-llm/dist/libs/snowflake-arctic-embed-m-q0f32-metal.so",
    ),
    device="metal:0",
    # An optional debug_dir=<path> keyword was used here during development.
)

# The HF input_ids start with 101 and end (before padding) with 102, i.e.
# [CLS] ... [SEP]. Write those markers into the text so the MLC tokenizer can be
# compared id-for-id with the HF tokenizer.
mlc_queries = ["[CLS] " + query + " [SEP]" for query in queries_with_prefix]
mlc_documents = ["[CLS] " + document + " [SEP]" for document in documents]
# NOTE(review): _tokenize_queries is a private MLCEmbeddings helper; it is used
# here only to compare token ids against the HF tokenizer.
mlc_tokens = mlc_embeddings._tokenize_queries(mlc_queries)

print("mlc_queries[0]", mlc_queries[0])
print("mlc_documents[0]", mlc_documents[0])

[2024-05-01 10:11:17] INFO chat_module.py:379: Using model folder: /Users/cfruan/Documents/mlc-llm/dist/snowflake-arctic-embed-m-q0f32-MLC
[2024-05-01 10:11:17] INFO chat_module.py:380: Using mlc chat config: /Users/cfruan/Documents/mlc-llm/dist/snowflake-arctic-embed-m-q0f32-MLC/mlc-chat-config.json


mlc_queries[0] [CLS] Represent this sentence for searching relevant passages: what is snowflake? [SEP]
mlc_documents[0] [CLS] The Data Cloud! [SEP]


In [16]:
# Assert tokenization of input is the same for MLC and HF.
import numpy as np

hf_input_ids = query_tokens["input_ids"].numpy()
# NOTE(review): mlc_tokens[0] appears to hold one token-id row per query
# (indexed by batch position) -- confirm against MLCEmbeddings._tokenize_queries.
mlc_input_ids = mlc_tokens[0]

# Compare batch sizes first. Iterating only over one side would otherwise skip
# any rows the other side produced without raising an error.
assert len(mlc_input_ids) == len(hf_input_ids), (len(mlc_input_ids), len(hf_input_ids))
for i in range(len(hf_input_ids)):
    np.testing.assert_array_equal(
        mlc_input_ids[i], hf_input_ids[i], err_msg=f"token ids differ for query {i}"
    )

### Compare embeddings output

In [29]:
import torch

# Run the HF reference model on the tokenized queries. Inference only, so skip
# building an autograd graph (the scoring cell also uses torch.no_grad()).
with torch.no_grad():
    query_embeddings = model(**query_tokens)

# query_embeddings[0] (last hidden state) has shape
# (batch_size, longest_seq_len, hidden_size).
print("query_embeddings[0].shape: ", query_embeddings[0].shape)
print(query_embeddings[0])

query_embeddings[0].shape:  torch.Size([2, 19, 768])
tensor([[[ 0.3311,  0.8075,  0.1499,  ...,  0.6162, -0.0541,  0.2450],
         [ 0.3530,  1.0479,  0.0503,  ...,  0.9943, -0.1670,  0.3151],
         [ 0.2639,  1.1741, -0.1657,  ...,  1.0260,  0.0698,  0.0333],
         ...,
         [ 0.1866,  0.8087,  0.0610,  ...,  0.6060, -0.0143,  0.5403],
         [ 0.2869,  0.8392,  0.1528,  ...,  0.6220, -0.0822,  0.4867],
         [ 0.1527,  0.8309,  0.1612,  ...,  0.5274, -0.0802,  0.5454]],

        [[-0.1421, -0.0361,  0.6161,  ...,  0.2524,  0.0108,  0.5810],
         [ 0.1238,  0.0119,  0.4495,  ...,  0.2463, -0.1506,  0.9454],
         [-0.0113, -0.0466,  0.2424,  ...,  0.5182, -0.1246,  1.1071],
         ...,
         [-0.1490,  0.3295,  0.2802,  ...,  0.2021, -0.1088,  1.0623],
         [-0.2999, -0.1606,  0.3321,  ...,  0.7638, -0.0594,  1.0833],
         [ 0.0363,  0.0367,  0.9930,  ...,  0.2296, -0.0787,  0.8901]]],
       grad_fn=<NativeLayerNormBackward0>)


In [30]:
# Compare the per-token MLC hidden states with the HF reference hidden states
# (query_embeddings[0]); both are (batch_size, longest_seq_len, hidden_size).
mlc_query_embeds = mlc_embeddings.embed(mlc_queries).numpy()
print("mlc_query_embeds.shape: ", mlc_query_embeds.shape)

hf_query_hidden = query_embeddings[0].detach().numpy()
assert mlc_query_embeds.shape == hf_query_hidden.shape, (
    mlc_query_embeds.shape,
    hf_query_hidden.shape,
)

# Compare only the real (non-padding) token positions given by the attention
# mask. Position 0 ([CLS]), which the scoring uses, is always among them.
real_tokens = query_tokens["attention_mask"].numpy().astype(bool)

# `decimal` in assert_array_almost_equal counts decimal places (an int), so
# decimal=0.01 would tolerate differences up to 1.5 * 10**-0.01 ~= 1.47 and check
# almost nothing. Use an explicit absolute tolerance of 0.01 instead.
np.testing.assert_allclose(
    mlc_query_embeds[real_tokens], hf_query_hidden[real_tokens], rtol=0, atol=1e-2
)

mlc_query_embeds.shape:  (2, 19, 768)


### Compare score computation

In [34]:
# Rank documents per query using MLC embeddings:
#   1. keep the hidden state at position 0 ([CLS]) of every sequence,
#   2. L2-normalize it,
#   3. score with dot products (cosine similarity of unit vectors).
mlc_query_embeds = mlc_embeddings.embed(mlc_queries).numpy()[:, 0]
mlc_query_embeds = mlc_query_embeds / np.linalg.norm(
    mlc_query_embeds, axis=1, keepdims=True
)

mlc_document_embeds = mlc_embeddings.embed(mlc_documents).numpy()[:, 0]
mlc_document_embeds = mlc_document_embeds / np.linalg.norm(
    mlc_document_embeds, axis=1, keepdims=True
)

# scores[q, d] = similarity of query q to document d.
scores = mlc_query_embeds @ mlc_document_embeds.T

for query, query_scores in zip(queries, scores):
    print(f"Query: {query}")
    doc_score_pairs = sorted(
        zip(documents, query_scores), key=lambda pair: pair[1], reverse=True
    )
    for doc, score in doc_score_pairs:
        print(f"{score}: {doc}")

Query: what is snowflake?
0.2747487425804138: The Data Cloud!
0.19997990131378174: Mexico City of Course!
Query: Where can I get the best tacos?
0.29974812269210815: Mexico City of Course!
0.2344069629907608: The Data Cloud!


In [35]:
import torch

# Reference scores from the HF model, computed the same way as the MLC scores:
# hidden state at position 0 ([CLS]), L2-normalized, then dot products.
# The hf_* names leave `query_embeddings` (the full model output compared against
# MLC earlier) untouched, so that comparison cell can still be re-run.
with torch.no_grad():
    hf_query_embeds = model(**query_tokens)[0][:, 0]
    hf_document_embeds = model(**document_tokens)[0][:, 0]

# normalize embeddings
hf_query_embeds = torch.nn.functional.normalize(hf_query_embeds, p=2, dim=1)
hf_document_embeds = torch.nn.functional.normalize(hf_document_embeds, p=2, dim=1)

# hf_scores[q, d] = similarity of query q to document d.
hf_scores = torch.mm(hf_query_embeds, hf_document_embeds.transpose(0, 1))
for query, query_scores in zip(queries, hf_scores):
    doc_score_pairs = sorted(zip(documents, query_scores), key=lambda x: x[1], reverse=True)
    # Output passages & scores
    print("Query:", query)
    for document, score in doc_score_pairs:
        print(score, document)

Query: what is snowflake?
tensor(0.2747) The Data Cloud!
tensor(0.2000) Mexico City of Course!
Query: Where can I get the best tacos?
tensor(0.2997) Mexico City of Course!
tensor(0.2344) The Data Cloud!


### Using Chroma as the vector database

This section follows the LangGraph local Llama 3 RAG agent example: https://github.com/langchain-ai/langgraph/blob/main/examples/rag/langgraph_rag_agent_llama3_local.ipynb

In [37]:
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.embeddings import Embeddings
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.vectorstores import Chroma

from typing import List

In [38]:
class ArcticEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter around the MLC-compiled arctic-embed model.

    Each text is wrapped as ``"[CLS] <text> [SEP]"``. This is the same format used
    for ``mlc_queries`` / ``mlc_documents`` when the MLC tokenization was checked
    against the HF tokenizer. The wrapped text is embedded with the notebook-global
    ``mlc_embeddings``, and the hidden state at position 0 ([CLS]) is L2-normalized
    and returned as the embedding.
    """

    # Retrieval instruction prepended to queries only (not to documents); the
    # same string as `query_prefix` in the HF comparison.
    QUERY_PREFIX = "Represent this sentence for searching relevant passages: "

    @staticmethod
    def _embed(texts: List[str]) -> List[List[float]]:
        """Return one unit-norm [CLS] embedding (list of floats) per input text."""
        if not texts:
            # Nothing to embed; avoid sending an empty batch to the engine.
            return []
        parsed_texts = ["[CLS] " + text + " [SEP]" for text in texts]
        cls_embeds = mlc_embeddings.embed(parsed_texts).numpy()[:, 0]
        cls_embeds = cls_embeds / np.linalg.norm(cls_embeds, axis=1, keepdims=True)
        return cls_embeds.tolist()

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed documents as-is (no query prefix)."""
        return self._embed(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a search query, prefixed with the retrieval instruction."""
        return self._embed([self.QUERY_PREFIX + text])[0]

In [43]:
import os

# Persistent Chroma collection that embeds text with the MLC-backed adapter.
# Set RAG_PERSIST_DIR to choose a different on-disk location.
chroma_client = Chroma(
    collection_name="mlc_rag",
    embedding_function=ArcticEmbeddings(),
    persist_directory=os.environ.get("RAG_PERSIST_DIR", "/Users/cfruan/Documents/mlc-llm/rag"),
)
# NOTE(review): the collection is persisted on disk, so re-running this cell
# presumably adds the same chunks again -- clear the directory first if so.

urls = [
    "https://lilianweng.github.io/posts/2023-06-23-agent/",
    "https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/",
    "https://lilianweng.github.io/posts/2023-10-25-adv-attack-llm/",
]

# Each loader returns a list of Documents; flatten them into a single list.
docs = [WebBaseLoader(url).load() for url in urls]
docs_list = [item for sublist in docs for item in sublist]

# Split into chunks whose size is measured with a tiktoken encoder.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=250, chunk_overlap=0
)
doc_splits = text_splitter.split_documents(docs_list)

# Add chunks to the store in batches of `batch_size` documents. This is a
# document count, unrelated to the splitter's chunk_size above.
batch_size = 20
for i in range(0, len(doc_splits), batch_size):
    batch = doc_splits[i : i + batch_size]
    # Report the real end index; the final batch may be shorter than batch_size.
    print("Adding documents {} to {}".format(i, i + len(batch)))
    chroma_client.add_documents(batch)
retriever = chroma_client.as_retriever()

[2024-05-01 10:34:40] INFO segment.py:189: Collection mlc_rag is not created.
[2024-05-01 10:34:40] INFO web_base.py:105: fake_useragent not found, using default user agent.To get a realistic header for requests, `pip install fake_useragent`.
[2024-05-01 10:34:40] INFO web_base.py:105: fake_useragent not found, using default user agent.To get a realistic header for requests, `pip install fake_useragent`.
[2024-05-01 10:34:40] INFO web_base.py:105: fake_useragent not found, using default user agent.To get a realistic header for requests, `pip install fake_useragent`.


Adding documents 0 to 20
Adding documents 20 to 40
Adding documents 40 to 60
Adding documents 60 to 80
Adding documents 80 to 100
Adding documents 100 to 120
Adding documents 120 to 140
Adding documents 140 to 160
Adding documents 160 to 180
Adding documents 180 to 200


In [44]:
# Retrieve chunks for a sample question through the Chroma retriever, then show
# the page text of the result at index 1 of the returned list.
question = "agent memory"
docs = retriever.invoke(input=question)
doc_txt = docs[1].page_content
doc_txt

'Sensory memory as learning embedding representations for raw inputs, including text, image or other modalities;\nShort-term memory as in-context learning. It is short and finite, as it is restricted by the finite context window length of Transformer.\nLong-term memory as the external vector store that the agent can attend to at query time, accessible via fast retrieval.\n\nMaximum Inner Product Search (MIPS)#\nThe external memory can alleviate the restriction of finite attention span.  A standard practice is to save the embedding representation of information into a vector store database that can support fast maximum inner-product search (MIPS). To optimize the retrieval speed, the common choice is the approximate nearest neighbors (ANN)\u200b algorithm to return approximately top k nearest neighbors to trade off a little accuracy lost for a huge speedup.\nA couple common choices of ANN algorithms for fast MIPS:'