In [5]:
import sys
print(sys.path)

['/Users/cfruan/Documents/mlc-llm', '/Users/cfruan/miniconda3/envs/mlc-chat-venv/lib/python311.zip', '/Users/cfruan/miniconda3/envs/mlc-chat-venv/lib/python3.11', '/Users/cfruan/miniconda3/envs/mlc-chat-venv/lib/python3.11/lib-dynload', '', '/Users/cfruan/miniconda3/envs/mlc-chat-venv/lib/python3.11/site-packages', '/Users/cfruan/Documents/tvm/python', '/Users/cfruan/Documents/mlc-llm/python']


In [6]:
sys.path.append("/Users/cfruan/Documents/tvm/python")
sys.path.append("/Users/cfruan/Documents/mlc-llm/python")

In [7]:
from transformers import AutoModel, AutoTokenizer
from sentence_transformers import SentenceTransformer, quantize_embeddings
from mlc_llm.embeddings.embeddings import MLCEmbeddings
import tvm
import torch
import numpy as np

### Load model from HF

In [8]:
snowflake_model_path = "/Users/cfruan/Documents/models/snowflake-arctic-embed-m"

tokenizer = AutoTokenizer.from_pretrained(snowflake_model_path)
model = AutoModel.from_pretrained(snowflake_model_path, add_pooling_layer=False)
model.eval()

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0-11): 12 x BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=False)
  

### Prepare inputs

In [9]:
query_prefix = "Represent this sentence for searching relevant passages: "
queries = ["what is snowflake?", "Where can I get the best tacos?"]
queries_with_prefix = ["{}{}".format(query_prefix, i) for i in queries]
query_tokens = tokenizer(
    queries_with_prefix, padding=True, truncation=True, return_tensors="pt", max_length=512
)

documents = ["The Data Cloud!", "Mexico City of Course!"]

mlc_queries = ["[CLS] " + query + " [SEP]" for query in queries_with_prefix]
mlc_documents = ["[CLS] " + document + " [SEP]" for document in documents]

### HF Embeddings Quantization

In [10]:
# Output is of shape (batch_size, longest_seq_len, hidden_size)
# We slice it to (batch_size, hidden_size)
query_embeddings_hf = model(**query_tokens)[0][:, 0].detach().numpy()
query_int8_hf = quantize_embeddings(query_embeddings_hf, "int8")
query_binary_hf = quantize_embeddings(query_embeddings_hf, "binary")



### MLC Embeddings Quantization

In [12]:
mlc_embeddings = MLCEmbeddings(
    "/Users/cfruan/Documents/mlc-llm/dist/snowflake-arctic-embed-m-q0f32-MLC",
    "/Users/cfruan/Documents/mlc-llm/dist/libs/snowflake-arctic-embed-m-q0f32-metal.so",
    device="auto",
    # debug_dir="/Users/cfruan/Documents/mlc-llm-repos/mlc-llm-head/debug",
)

[2024-05-02 18:53:15] INFO chat_module.py:379: Using model folder: /Users/cfruan/Documents/mlc-llm/dist/snowflake-arctic-embed-m-q0f32-MLC
[2024-05-02 18:53:15] INFO chat_module.py:380: Using mlc chat config: /Users/cfruan/Documents/mlc-llm/dist/snowflake-arctic-embed-m-q0f32-MLC/mlc-chat-config.json


In [16]:
# Output is of shape (batch_size, longest_seq_len, hidden_size)
# We slice it to (batch_size, hidden_size)
query_embeddings_mlc = mlc_embeddings.embed(mlc_queries).numpy()[:,0]
query_embeddings_mlc_tvm = tvm.nd.array(query_embeddings_mlc, device=mlc_embeddings.device)

query_int8_mlc = mlc_embeddings.quantize_embeddings(query_embeddings_mlc_tvm, "int8").numpy()
query_binary_mlc = mlc_embeddings.quantize_embeddings(query_embeddings_mlc_tvm, "binary").numpy()
query_binary_mlc = (np.packbits(query_binary_mlc).reshape(query_binary_mlc.shape[0], -1) - 128).astype(np.int8)

In [17]:
np.testing.assert_allclose(query_embeddings_mlc, query_embeddings_hf, rtol=0.01)
np.testing.assert_allclose(query_int8_mlc, query_int8_mlc)
np.testing.assert_allclose(query_binary_mlc, query_binary_hf)

In [14]:
# Alternatively in a single kernel
query_embeddings_mlc = mlc_embeddings.embed_binary(mlc_queries).numpy()
query_binary_mlc = (np.packbits(query_embeddings_mlc).reshape(query_embeddings_mlc.shape[0], -1) - 128).astype(np.int8)

In [15]:
np.testing.assert_allclose(query_binary_mlc, query_binary_hf)

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
