# How to Choose an Embedding Model

In [1]:
!pip3 install pytrec_eval



In [2]:
import pandas as pd
pd.set_option('max_colwidth', None)

df = pd.read_csv('pastry_data.csv', delimiter=';')
df

Unnamed: 0,pastry_id,pastry_name,pastry_description
0,1,Bagel,"A classic round bread with a chewy interior and a golden crust, perfect for toasting or topping with cream cheese, smoked salmon, or your favorite spreads."
1,2,Roll,"Soft and pillowy, these rolls are versatile companions to any meal. Enjoy them fresh out of the oven or sliced for sandwiches, filled with your choice of meats, cheeses, and veggies."
2,3,Donut,"Indulge in these sweet, fried delights. Glazed, powdered, or filled with decadent creams and fruit jams, each bite is a delightful burst of flavor and nostalgia."
3,4,Muffin,"Moist and tender, these muffins are bursting with flavor. Whether you prefer classic blueberry, indulgent chocolate chip, or hearty bran, there's a muffin for every craving."
4,5,Croissant,"Buttery layers of flaky pastry are folded to perfection, creating a delicate and irresistible treat. Enjoy them plain, stuffed with savory fillings, or paired with your favorite coffee."
5,6,Scone,"Crumbly yet tender, these scones are the epitome of comfort. Enjoy them plain or studded with fruits, nuts, or chocolate chips, accompanied by a dollop of clotted cream and jam."
6,7,Pretzel,"Crunchy on the outside, soft on the inside, these pretzels are a salty snack lover's dream. Enjoy them twisted into traditional shapes or dipped in sweet or savory toppings for a delightful twist."
7,8,Sandwich,"Freshly baked bread is the foundation for these hearty sandwiches. Pile on layers of meats, cheeses, crisp vegetables, and flavorful spreads for a satisfying meal on the go."


In [3]:
import weaviate
import json

try:
    client= weaviate.connect_to_local(port=8079, grpc_port=50050)
except:
    client = weaviate.connect_to_embedded(
        headers = {
            "X-OpenAI-Api-Key": os.environ["OPENAI_API_KEY"]
        }
    )

print(client.is_ready())

Started /Users/leonie/.cache/weaviate-embedded: process ID 87673


{"action":"startup","default_vectorizer_module":"none","level":"info","msg":"the default vectorizer modules is set to \"none\", as a result all new schema classes without an explicit vectorizer setting, will use this vectorizer","time":"2024-05-08T12:29:09+02:00"}
{"action":"startup","auto_schema_enabled":true,"level":"info","msg":"auto schema enabled setting is set to \"true\"","time":"2024-05-08T12:29:09+02:00"}
{"level":"info","msg":"No resource limits set, weaviate will use all available memory and CPU. To limit resources, set LIMIT_RESOURCES=true","time":"2024-05-08T12:29:09+02:00"}
{"action":"grpc_startup","level":"info","msg":"grpc server listening at [::]:50050","time":"2024-05-08T12:29:09+02:00"}
{"action":"restapi_management","level":"info","msg":"Serving weaviate at http://127.0.0.1:8079","time":"2024-05-08T12:29:09+02:00"}


True


In [4]:
import weaviate.classes as wvc
from weaviate.classes.config import Property, DataType

if client.collections.exists("Pastries"):
    client.collections.delete("Pastries")

pastries = client.collections.create(
    name="Pastries",
    vectorizer_config=wvc.config.Configure.Vectorizer.text2vec_openai(model='text-embedding-3-large'),
    properties=[
        Property(name="pastry_name", data_type=DataType.TEXT),
        Property(name="pastry_description", data_type=DataType.TEXT),
    ]
)

pastry_objects = list()
for _, row in df.iterrows():
    properties = {
        "pastry_name": row.pastry_name,
        "pastry_description": row.pastry_description
    }
    pastry_objects.append(properties)

pastries.data.insert_many(pastry_objects)

{"level":"info","msg":"Created shard pastries_6h48j8vP30f0 in 1.561167ms","time":"2024-05-08T12:29:09+02:00"}
{"action":"hnsw_vector_cache_prefill","count":1000,"index_id":"main","level":"info","limit":1000000000000,"msg":"prefilled vector cache","time":"2024-05-08T12:29:09+02:00","took":41958}
{"level":"info","msg":"Completed loading shard dim_512_vIr0JrwWdLwP in 3.483791ms","time":"2024-05-08T12:29:10+02:00"}
{"action":"hnsw_vector_cache_prefill","count":3000,"index_id":"main","level":"info","limit":1000000000000,"msg":"prefilled vector cache","time":"2024-05-08T12:29:10+02:00","took":219958}
{"level":"info","msg":"Completed loading shard dimensions__512_ALiufnMMSrhQ in 5.809542ms","time":"2024-05-08T12:29:10+02:00"}
{"level":"info","msg":"Completed loading shard dimension_1536_xLwNXnHZeLcm in 6.313459ms","time":"2024-05-08T12:29:10+02:00"}
{"action":"hnsw_vector_cache_prefill","count":3000,"index_id":"main","level":"info","limit":1000000000000,"msg":"prefilled vector cache","time":"

BatchObjectReturn(all_responses=[UUID('74d4e157-886b-4075-9849-5b5c30733ce6'), UUID('8556cc0c-f8f3-4bd2-95d7-76aee821e1df'), UUID('bc49a7d3-c2c4-4fed-80fb-717aff02f8b3'), UUID('5660da7c-396d-4663-bfe1-6439e8b261e4'), UUID('47ca5c52-f841-45d7-99c7-68136d1abfa3'), UUID('32cb8665-c9b3-4482-840d-6003d6f22672'), UUID('342e40f9-f352-4155-803b-f3e644e6a6a6'), UUID('6a09544a-21f2-42b4-9636-a45361cfb887')], elapsed_seconds=0.7803750038146973, errors={}, uuids={0: UUID('74d4e157-886b-4075-9849-5b5c30733ce6'), 1: UUID('8556cc0c-f8f3-4bd2-95d7-76aee821e1df'), 2: UUID('bc49a7d3-c2c4-4fed-80fb-717aff02f8b3'), 3: UUID('5660da7c-396d-4663-bfe1-6439e8b261e4'), 4: UUID('47ca5c52-f841-45d7-99c7-68136d1abfa3'), 5: UUID('32cb8665-c9b3-4482-840d-6003d6f22672'), 6: UUID('342e40f9-f352-4155-803b-f3e644e6a6a6'), 7: UUID('6a09544a-21f2-42b4-9636-a45361cfb887')}, has_errors=False)

# Setup Evaluation

In [5]:
from weaviate.classes.query import MetadataQuery
import pytrec_eval

def get_results(query):
    response = pastries.query.near_text(
        query=query,
        limit=4,
        return_metadata=MetadataQuery(distance=True)
    )

    result_dict = {}
    for o in response.objects:
        result_dict[o.properties['pastry_name']] = o.metadata.distance

    return result_dict

qrel = {
    'Sweet pastry' : {
        'Donut' : 1,
        'Muffin' : 1,
        'Scone' : 1, 
    },
    'Suitable for lunch' : {
        'Sandwich' : 1,
        'Bagel' : 1,
        'Roll' : 1,
        'Pretzel' : 1
    },
    'Goes well with jam' : {
        'Bagel' : 1,
        'Croissant' : 1,
        'Roll' : 1,
    },
}

{"action":"hnsw_vector_cache_prefill","count":3000,"index_id":"main","level":"info","limit":1000000000000,"msg":"prefilled vector cache","time":"2024-05-08T12:29:10+02:00","took":6396625}
{"action":"hnsw_vector_cache_prefill","count":3000,"index_id":"main","level":"info","limit":1000000000000,"msg":"prefilled vector cache","time":"2024-05-08T12:29:10+02:00","took":9203584}


In [6]:
queries = ["Sweet pastry", "Suitable for lunch", "Goes well with jam"]

run = {}
for q in queries:
    run[q] = get_results(q)

run

{'Sweet pastry': {'Croissant': 0.4081610441207886,
  'Donut': 0.4437391757965088,
  'Scone': 0.4496232271194458,
  'Sandwich': 0.48311829566955566},
 'Suitable for lunch': {'Sandwich': 0.6699223518371582,
  'Roll': 0.710200846195221,
  'Bagel': 0.7621301412582397,
  'Croissant': 0.7663571238517761},
 'Goes well with jam': {'Scone': 0.6245813369750977,
  'Bagel': 0.6838458776473999,
  'Donut': 0.6856768131256104,
  'Sandwich': 0.7056261301040649}}

In [7]:
# Evaluate
evaluator = pytrec_eval.RelevanceEvaluator(
    qrel, { 'recall.4', 'P.4'})

print(json.dumps(evaluator.evaluate(run), indent=1))

{
 "Sweet pastry": {
  "P_4": 0.5,
  "recall_4": 0.6666666666666666
 },
 "Suitable for lunch": {
  "P_4": 0.75,
  "recall_4": 0.75
 },
 "Goes well with jam": {
  "P_4": 0.25,
  "recall_4": 0.3333333333333333
 }
}
