# Mistral Embed

Mistral Embed, though not widely used, is an interesting candidate for experimenting with Binary Quantization because of its multilingual capabilities in European languages, e.g. English, French, and German. Here, though, we only use embeddings created for English text.

## Setting up the environment

Our dependencies are specified in the `pyproject.toml` file which ships with this notebook. You can install them using Poetry by running the following command in the terminal:

```bash
poetry install --no-root
```

In [1]:
import json
import os

import loguru
import numpy as np
import pandas as pd

# The code in this notebook reads its settings (API keys, Qdrant URL) from environment
# variables (via `os.environ` / `os.getenv`) as if they came from the actual environment;
# `load_dotenv()` below fills them in from a local `.env` file.
from datasets import load_dataset
from dotenv import load_dotenv
from qdrant_client import QdrantClient, models
from qdrant_client.models import PointStruct
from tqdm import tqdm

# Load variables from a local .env file into the process environment; later
# cells read MISTRAL_API_KEY, QDRANT_URL and QDRANT_API_KEY from it.
load_dotenv()

# Use loguru's shared logger and additionally write INFO-and-above records
# to logs.log.
logger = loguru.logger
logger.add(
    "logs.log",
    level="INFO",
    format="{time} {level} {message}",
)

  from .autonotebook import tqdm as notebook_tqdm


1

In [None]:
# Download the complete train split (streaming=False) of the 100K DBpedia
# entities dataset. Its `title` and `text` fields are combined below; the
# `embedding` column it ships with is dropped in the next cell.
dataset = load_dataset(
    "nirantk/dbpedia-entities-google-palm-gemini-embedding-001-100K",
    streaming=False,
    split="train",
)
dataset

In [None]:
# Drop the embedding column that came with the dataset; a column of the same
# name holding Mistral embeddings is added back further down.
dataset = dataset.remove_columns(column_names=["embedding"])

In [None]:
def _with_combined_text(row):
    """Return the extra `combined_text` field: title and text joined by a newline."""
    return {"combined_text": f"{row['title']}\n{row['text']}"}


dataset = dataset.map(_with_combined_text)

In [None]:
# Show the dataset summary again after the column changes above.
dataset

In [None]:
from mistralai.client import MistralClient

# Read the Mistral API key from the environment (raises KeyError if it is unset).
api_key = os.environ["MISTRAL_API_KEY"]
# NOTE(review): the name `client` is rebound to a QdrantClient in a later cell,
# so this cell has to be re-run before the embedding loop can be repeated.
client = MistralClient(api_key=api_key)

In [None]:
combined_text = dataset["combined_text"]

# Embed the texts in fixed-size batches and keep every raw API response; the
# vectors themselves are pulled out of the responses in the next cell.
bs = 10
response_objects = []
batch_starts = range(0, len(combined_text), bs)
for start in tqdm(batch_starts):
    batch_texts = list(combined_text[start : start + bs])
    response_objects.append(
        client.embeddings(model="mistral-embed", input=batch_texts)
    )

In [None]:
# Each response carries the embedding objects of one batch in `.data`.
embedding_responses = [response.data for response in response_objects]

# Concatenate the per-batch lists, keeping the original row order.
embedding_objects = []
for batch_items in embedding_responses:
    embedding_objects.extend(batch_items)

# Keep only the raw vectors and attach them to the dataset as a new column.
embeddings = [obj.embedding for obj in embedding_objects]
dataset = dataset.add_column("embedding", embeddings)

In [None]:
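# One-off publishing step, left disabled: uploads the embedded dataset to the Hub
# under the name that the next section loads it back from.
# dataset.push_to_hub("nirantk/dbpedia-entities-mistral-embeddings-100K")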

# Use the Dataset from the Hugging Face Hub

In [11]:
# Load the pre-computed Mistral embeddings (full train split, not streamed).
dataset = load_dataset(
    "nirantk/dbpedia-entities-mistral-embeddings-100K",
    streaming=False,
    split="train",
)

# Build one Qdrant point per row: the row position is the point id, the stored
# embedding is the vector, and title/text travel along as payload.
# Iterating the rows once avoids materialising the whole embedding column a
# second time and skips the throwaway list of dicts the points used to be
# built from.
points = [
    PointStruct(
        id=i,
        vector=row["embedding"],
        payload={"text": row["text"], "title": row["title"]},
    )
    for i, row in enumerate(dataset)
]

Downloading readme: 100%|██████████| 432/432 [00:00<00:00, 2.39MB/s]
Downloading data: 100%|██████████| 126M/126M [00:30<00:00, 4.15MB/s] 
Downloading data: 100%|██████████| 126M/126M [00:25<00:00, 5.01MB/s] 
Generating train split: 100%|██████████| 100000/100000 [00:01<00:00, 71302.72 examples/s]


In [12]:
# Connect to Qdrant using the URL and API key from the environment
# (os.getenv returns None for a variable that is not set).
# NOTE(review): this rebinds `client`, which earlier referred to the
# MistralClient; from here on `client` is the Qdrant client.
client = QdrantClient(
    url=os.getenv("QDRANT_URL"),
    api_key=os.getenv("QDRANT_API_KEY"),
    timeout=100,
)

# Setting up a Collection with Binary Quantization

In [13]:
# Name of the Qdrant collection used by every cell below.
collection_name = "mistral-embed"

In [15]:
# Vector space: 1024-dimensional vectors compared by cosine distance.
bq_vectors_config = models.VectorParams(
    size=1024,
    distance=models.Distance.COSINE,
)

# indexing_threshold=0 for the bulk upload; the threshold is raised again in
# the "Turn on Indexing" step once all points are uploaded.
bq_optimizers_config = models.OptimizersConfigDiff(
    default_segment_number=5,
    indexing_threshold=0,
)

# Binary quantization, configured with always_ram=True.
bq_quantization_config = models.BinaryQuantization(
    binary=models.BinaryQuantizationConfig(always_ram=True),
)

# (Re)create the collection with the settings above, split over two shards.
client.recreate_collection(
    collection_name=collection_name,
    vectors_config=bq_vectors_config,
    optimizers_config=bq_optimizers_config,
    quantization_config=bq_quantization_config,
    shard_number=2,
)

True

In [16]:
collection_info = client.get_collection(collection_name=collection_name)

# Decide emptiness from the point count rather than `vectors_count`.
# NOTE(review): newer Qdrant releases mark `vectors_count` as deprecated and
# may report it as None, in which case `vectors_count == 0` is False and the
# upload would be skipped silently — confirm against the deployed version.
# A missing (None) point count is treated as empty; re-uploading is harmless
# because the points are upserted under fixed ids.
if not collection_info.points_count:
    logger.info("Collection is empty. Begin upsert.")
    bs = 1000  # Batch size
    for i in tqdm(range(0, len(points), bs)):
        slice_points = points[i : i + bs]  # Create a slice of bs points
        client.upsert(collection_name=collection_name, points=slice_points)
else:
    logger.info(
        f"Collection already holds {collection_info.points_count} points. Skipping upsert."
    )
    

[32m2024-01-15 16:03:47.633[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m4[0m - [1mCollection is empty. Begin upsert.[0m
100%|██████████| 100/100 [04:07<00:00,  2.48s/it]


In [17]:
# Re-read the collection info and show how many points it now holds.
# Every point in this collection carries exactly one vector, so the point
# count equals the vector count that `vectors_count` used to report.
# NOTE(review): `vectors_count` is deprecated in newer Qdrant releases and may
# be None there — confirm against the deployed version.
collection_info = client.get_collection(collection_name=collection_name)
collection_info.points_count

100000

### Turn on Indexing

In [18]:
# Raise the indexing threshold from the 0 used during the bulk upload to 20000.
# `optimizers_config` is the keyword used here; the previous spelling
# `optimizer_config` appears to be only a legacy alias in the client
# (NOTE(review): confirm against the installed qdrant-client version).
client.update_collection(
    collection_name=collection_name,
    optimizers_config=models.OptimizersConfigDiff(indexing_threshold=20000),
)

True

In [21]:
# Sanity check: query the collection with the stored vector of point 32 and
# show its five nearest neighbours.
demo_search_params = models.SearchParams(
    exact=True,
    quantization=models.QuantizationSearchParams(
        ignore=False,
        rescore=False,
        oversampling=2.0,
    ),
)
client.search(
    collection_name=collection_name,
    query_vector=points[32].vector,
    search_params=demo_search_params,
    limit=5,
)

[ScoredPoint(id=32, version=0, score=1.0, payload={'text': 'Sobrassada (Catalan pronunciation: [soβɾəˈsaðə]; Spanish: sobrasada) is a raw, cured sausage from the Balearic Islands made with ground pork, paprika and salt and other spices. Sobrassada, along with botifarró are traditional Balearic sausage meat products prepared in the laborious but festive rites that still mark the autumn and winter pig slaughter known as a matança (in Spanish, matanza) in Majorca and Eivissa.', 'title': 'Sobrassada'}, vector=None, shard_key=None),
 ScoredPoint(id=98204, version=98, score=0.806023, payload={'text': 'Krakowska (pronounced /krəˈkɒvskə/ krə-KOV-skə) is a type of Polish sausage (kielbasa), usually served as a cold cut. The name derives from the city of Kraków (mediaeval capital of the Polish-Lithuanian Commonwealth till late 16th century). It is made from cuts of lean pork, seasoned with pepper, allspice, coriander, and garlic, packed into large casings, and smoked. English speaking countries 

In [22]:
# Hold out a reproducible (seed=37), shuffled 10% split of the dataset.
# NOTE(review): `ds` is not referenced by any later cell visible in this
# notebook; the grid search below queries with `points[10:100]` instead.
ds = dataset.train_test_split(test_size=0.1, shuffle=True, seed=37)["test"]

In [None]:
# Grid for the quantized search: oversampling factors 1.0, 2.0 and 3.0,
# each tried with and without rescoring.
oversampling_range = np.linspace(1.0, 3.0, num=3)
rescore_range = [True, False]


def parameterized_search(
    point,
    oversampling: float,
    rescore: bool,
    exact: bool,
    collection_name: str,
    ignore: bool = False,
    limit: int = 10,
):
    """Search `collection_name` with `point.vector` as the query.

    Args:
        point: Object whose `.vector` attribute is used as the query vector.
        oversampling: Oversampling factor for the quantized search.
        rescore: Whether the quantized search rescores its candidates.
        exact: If truthy, run an exact search; `oversampling`, `rescore` and
            `ignore` are then not sent at all.
        collection_name: Collection to search in.
        ignore: `ignore` flag of the quantization search params.
        limit: Maximum number of results to return.

    Returns:
        Whatever `client.search` returns for the request.
    """
    params_kwargs = {"exact": exact}
    if not exact:
        # Only the non-exact path carries quantization settings.
        params_kwargs["quantization"] = models.QuantizationSearchParams(
            ignore=ignore,
            rescore=rescore,
            oversampling=oversampling,
        )

    return client.search(
        collection_name=collection_name,
        query_vector=point.vector,
        search_params=models.SearchParams(**params_kwargs),
        limit=limit,
    )


# Result sizes (top-k) evaluated for every query point.
limit_range = [100, 50, 20, 10, 5, 1]

# Grid search: for each query point and each (oversampling, rescore, limit)
# combination, compare the quantized search against the exact search and
# record which fraction of the exact result ids it recovered.
# One JSON object is written per line, so the file can be read back with
# pd.read_json("results.json", lines=True).
with open("results.json", "w", encoding="utf-8") as f:
    for point in tqdm(points[10:100]):
        # The exact baseline depends only on the query point and `limit`
        # (parameterized_search sends no quantization settings when
        # exact=True), so it is fetched once per limit and reused across the
        # oversampling/rescore grid instead of being re-queried every time.
        exact_ids_by_limit = {}

        ## Running Grid Search
        for oversampling in oversampling_range:
            for rescore in rescore_range:
                for limit in limit_range:
                    try:
                        if limit not in exact_ids_by_limit:
                            exact = parameterized_search(
                                point=point,
                                oversampling=oversampling,
                                rescore=rescore,
                                exact=True,
                                collection_name=collection_name,
                                limit=limit,
                            )
                            exact_ids_by_limit[limit] = [item.id for item in exact]
                        hnsw = parameterized_search(
                            point=point,
                            oversampling=oversampling,
                            rescore=rescore,
                            exact=False,
                            collection_name=collection_name,
                            limit=limit,
                        )
                    except Exception as e:
                        # Best effort: skip this combination and carry on. Log
                        # only the id, not the point itself, whose repr
                        # includes the full vector.
                        logger.warning(
                            f"Skipping point {point.id} "
                            f"(oversampling={oversampling}, rescore={rescore}, "
                            f"limit={limit}): {e}"
                        )
                        continue

                    exact_ids = exact_ids_by_limit[limit]
                    if not exact_ids:
                        # No baseline to compare against; dividing by its
                        # length would raise ZeroDivisionError.
                        continue
                    hnsw_ids = [item.id for item in hnsw]

                    # Fraction of the exact result ids that the quantized
                    # search also returned.
                    accuracy = len(set(exact_ids) & set(hnsw_ids)) / len(exact_ids)

                    result = {
                        "query_id": point.id,
                        # Plain float rather than a NumPy scalar, so the
                        # logged dict and the JSON line read the same.
                        "oversampling": float(oversampling),
                        "rescore": rescore,
                        "limit": limit,
                        "accuracy": accuracy,
                    }
                    f.write(json.dumps(result))
                    f.write("\n")
                    logger.info(result)

In [2]:
import pandas as pd

# Load the grid-search output back in; results.json holds one JSON object per
# line, hence lines=True.
results = pd.read_json("results.json", lines=True)

In [3]:
# results.to_csv("results.csv", index=False)

# Mean accuracy per (oversampling, rescore, limit) combination, leaving the
# limit=1 and limit=5 runs out of the summary.
average_accuracy = (
    results.loc[~results["limit"].isin([1, 5])]
    .groupby(["oversampling", "rescore", "limit"], as_index=False)["accuracy"]
    .mean()
)

# One row per limit, one column per (oversampling, rescore) pair.
acc = average_accuracy.pivot(
    index="limit",
    columns=["oversampling", "rescore"],
    values="accuracy",
)
acc

oversampling,1,1,2,2,3,3
rescore,False,True,False,True,False,True
limit,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
10,0.534444,0.857778,0.534444,0.918889,0.533333,0.941111
20,0.508333,0.837778,0.508333,0.903889,0.508333,0.927778
50,0.492222,0.834444,0.492222,0.903556,0.492889,0.940889
100,0.499111,0.845444,0.498556,0.918333,0.497667,0.944556


In [5]:
# Render the oversampling=3.0 / rescore=True column as a Markdown table.
selected_config = (3.0, True)
markdown_table = acc.loc[:, selected_config].to_markdown()
print(markdown_table)


|   limit |   (3, True) |
|--------:|------------:|
|      10 |    0.941111 |
|      20 |    0.927778 |
|      50 |    0.940889 |
|     100 |    0.944556 |
