In [0]:
%pip install databricks_vectorsearch 
# Restart the Python process after the install so the rest of the notebook
# runs against the freshly installed package.
dbutils.library.restartPython()

In [0]:
from config import DeployConfig
import requests
import json
from databricks.vector_search.client import VectorSearchClient

In [0]:
# Expose the config file location as a notebook widget (with a default),
# read the current widget value back, and load the deployment config from it.
CONFIG_WIDGET_NAME = "config_path"
dbutils.widgets.text(CONFIG_WIDGET_NAME, "./config/env_variables.yml")
config_path = dbutils.widgets.get(CONFIG_WIDGET_NAME)
cfg = DeployConfig.from_yaml(config_path)

In [0]:
# Pull out the two config sections used by the rest of the notebook:
# `vs_index` (endpoint, index path, delta-sync source table) and
# `image_table` (path of the table holding the image content).
# Plain attribute access replaces getattr() with a constant name.
vs_index = cfg.vs_index
image_table = cfg.image_table

# VS INDEX CREATION

In [0]:
# Vector Search client; used further down to create and look up the index.
vsc = VectorSearchClient()

In [0]:
# Turn on Change Data Feed for the delta-sync source table before the
# delta-sync index is created from it.
enable_cdf_statement = f'ALTER TABLE {vs_index.delta_sync_table} SET TBLPROPERTIES (delta.enableChangeDataFeed = true)'
spark.sql(enable_cdf_statement)

In [0]:
# Create the delta-sync index over the source table's precomputed embeddings.
vsc.create_delta_sync_index(
    # Where the index lives.
    endpoint_name=vs_index.endpoint,
    index_name=vs_index.path,
    # What it is built from.
    source_table_name=vs_index.delta_sync_table,
    primary_key="id",
    columns_to_sync=["id", "path"],
    # Self-managed embeddings stored in the source table.
    # NOTE(review): 768 must equal the length of the vectors in
    # `image_embeddings` -- confirm it matches the embedding model in use.
    embedding_vector_column="image_embeddings",
    embedding_dimension=768,
    # How the index is refreshed.
    pipeline_type="TRIGGERED",
)

# QUERY VS INDEX

In [0]:
import mlflow
from transformers import CLIPProcessor, CLIPModel
from PIL import Image
from io import BytesIO

In [0]:
# Text embedding model, loaded locally for now.
# TODO: will eventually call a serving endpoint instead.
CLIP_CHECKPOINT = "openai/clip-vit-large-patch14"
model = CLIPModel.from_pretrained(CLIP_CHECKPOINT)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

In [0]:
def get_text_embedding(text):
    """Embed a text query with the notebook-level CLIP ``model``/``processor``.

    Args:
        text: The query string. If a list of strings is passed, only the
            embedding of the first entry is returned.

    Returns:
        list[float]: The text-feature vector for the (first) input, as plain
        Python floats so it can be sent as a vector-search query.
    """
    # Local import keeps this cell self-contained; torch is already a
    # dependency of the transformers CLIP model used here.
    import torch

    # truncation=True clips over-long prompts to the tokenizer's maximum
    # length instead of letting the model fail on too many tokens.
    inputs = processor(text=text, return_tensors="pt", padding=True, truncation=True)

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        text_features = model.get_text_features(**inputs)

    # Row 0 is the embedding of the first (or only) input text.
    return text_features[0].tolist()

In [0]:
# Example free-text query and its embedding.
text = "black cat in forest"
query = get_text_embedding(text)

# Look up the index through the `vsc` client created earlier in the notebook;
# VectorSearchClient is already imported in the top import cell, so neither
# the import nor the client construction is repeated here.
index = vsc.get_index(endpoint_name=vs_index.endpoint, index_name=vs_index.path)

In [0]:
# Retrieve the three closest entries to the query embedding, returning the
# `id` and `path` columns for each hit.
resp = index.similarity_search(
    query_vector=query,
    columns=["id", "path"],
    num_results=3,
)
resp

In [0]:
# Each hit is a row whose first element is the id (the search above requested
# columns=["id", "path"]). Fall back to an empty list when the response has
# no `result`/`data_array` entry, so a search without matches yields
# `ids == []` instead of raising KeyError.
hits = resp.get('result', {}).get('data_array') or []
ids = [hit[0] for hit in hits]
ids

In [0]:
# Fetch the image bytes for the matched ids through the DataFrame API rather
# than string-built SQL: an empty `ids` list no longer produces an invalid
# `IN ()` clause, and id values are passed as literals instead of being
# pasted into the query text.
images_df = spark.table(image_table.path)
matched_rows = (
    images_df
    .where(images_df["id"].isin(ids))
    .select("id", "content")
    .collect()
)

# Spark returns the rows in no particular order; put them back into the
# similarity ranking so that result[0] is the best match. Ids that cannot be
# matched back (e.g. because of a type difference) keep their relative order
# at the end.
rank_by_id = {image_id: position for position, image_id in enumerate(ids)}
result = sorted(matched_rows, key=lambda row: rank_by_id.get(row["id"], len(ids)))

In [0]:
# Decode and display the first returned image from its raw bytes.
# Change the index (1, 2, ...) to look at the other returned rows.
first_image_bytes = result[0]['content']
Image.open(BytesIO(first_image_bytes))