In [0]:
%pip install databricks_vectorsearch 
# Restart the Python process so the freshly installed package is picked up by later cells.
dbutils.library.restartPython()

In [0]:
from config import DeployConfig
import requests
import json

In [0]:
# Expose the config file location as a notebook widget (default: ./config/env_variables.yml).
dbutils.widgets.text("config_path", "./config/env_variables.yml")
config_path = dbutils.widgets.get("config_path")
# Build the deployment configuration object from the YAML file at that path.
cfg = DeployConfig.from_yaml(config_path)

In [0]:
# Config entries used by the agent below: `vs_index` supplies `.endpoint` and
# `.path` for the vector search index, `image_table` supplies `.path` for the
# table holding the image bytes.
vs_index = cfg.vs_index
image_table = cfg.image_table

# AGENT BUILD

In [0]:
import mlflow
from mlflow.deployments import get_deploy_client
from databricks.vector_search.index import VectorSearchIndex
from databricks.vector_search.client import VectorSearchClient
from mlflow.entities import SpanType
from PIL import Image
from io import BytesIO
import os
import requests

In [0]:
# Copy the service principal credentials from the Databricks secret scope into
# environment variables; CAT_AGENT.__init__ reads CLIENT_ID / CLIENT_SECRET from there.
secret_scope = "shovakeemian-scope"
for env_var, secret_key in {
    "CLIENT_ID": "shovakeemian-sp-client-id",
    "CLIENT_SECRET": "shovakeemian-sp-client-secret",
}.items():
    os.environ[env_var] = dbutils.secrets.get(secret_scope, secret_key)

In [0]:
class CAT_AGENT(mlflow.pyfunc.PythonModel):
    """Text-to-image retrieval agent.

    Pipeline implemented by ``predict``:

    1. embed the text query with CLIP (``_get_text_embedding``),
    2. run a similarity search against the configured vector search index
       (``_vector_search_retrieval``),
    3. load the matching image bytes from ``image_table`` and open them with
       PIL (``_get_images``).

    ``load_context`` must be called before ``predict`` because it loads the
    CLIP model and processor used for step 1.
    """

    def __init__(self, workspace_url="https://e2-demo-field-eng.cloud.databricks.com/"):
        """Create the HTTP session and the vector search client.

        Args:
            workspace_url: Databricks workspace URL passed to
                ``VectorSearchClient``. Defaults to the workspace this
                notebook was written against.

        Service principal credentials are read from the ``CLIENT_ID`` and
        ``CLIENT_SECRET`` environment variables; a missing variable is passed
        through to the client as ``None``.
        """
        import requests

        self.session = requests.Session()
        self.oauth = None

        # NOTE(review): the secret is stored on the instance, so it would travel
        # with the object if the instance is ever pickled (e.g. logged as an
        # MLflow model) — confirm this is acceptable.
        self.CLIENT_ID = os.environ.get("CLIENT_ID")
        self.CLIENT_SECRET = os.environ.get("CLIENT_SECRET")

        self.vsc = VectorSearchClient(
            workspace_url=workspace_url,
            service_principal_client_id=self.CLIENT_ID,
            service_principal_client_secret=self.CLIENT_SECRET,
        )

    def load_context(self, context):
        """Load the CLIP model and processor used to embed text queries.

        Args:
            context: Unused by this implementation.
        """
        from transformers import CLIPProcessor, CLIPModel

        # Will not be needed once ai_query can take params of a pyfunc: the serving
        # endpoint could then produce both image and text embeddings.
        self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

    @mlflow.trace(name="compute_text_embedding", span_type=SpanType.EMBEDDING, attributes={"model": "clip-vit-large-patch14"})
    def _get_text_embedding(self, text):
        """Compute the CLIP text embedding for ``text``.

        Args:
            text: Query text handed to the CLIP processor.

        Returns:
            The first row of the text features as a plain Python list of floats.
        """
        import torch

        # TODO: call the serving endpoint instead once ai_query can take params of a pyfunc.
        inputs = self.processor(text=text, return_tensors="pt", padding=True)
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            text_features = self.model.get_text_features(**inputs)
        return text_features.detach().numpy().tolist()[0]

    @mlflow.trace(name="cat_image_vs_imageemb_lookup", span_type=SpanType.RETRIEVER, attributes={"model": "clip-vit-large-patch14", "vs_index": vs_index.path})
    def _vector_search_retrieval(self, query):
        """Embed ``query`` and return the 3 nearest entries from the vector search index.

        Args:
            query: Text query to embed and search with.

        Returns:
            The raw ``similarity_search`` response; the ``id`` and ``path``
            columns are requested for each hit.
        """
        text_embed_query = self._get_text_embedding(query)
        index = self.vsc.get_index(endpoint_name=vs_index.endpoint, index_name=vs_index.path)
        vs_output = index.similarity_search(columns=["id", 'path'], query_vector=text_embed_query, num_results=3)
        return vs_output

    @mlflow.trace(name="cat_id_image_lookup", span_type=SpanType.RETRIEVER, attributes={"table": image_table.path})
    def _get_images(self, vs_output):
        """Yield PIL images for the hits of a vector search response.

        Args:
            vs_output: Vector search response; the first element of each row in
                ``vs_output['result']['data_array']`` is used as the image id.

        Yields:
            ``PIL.Image`` objects opened from the ``content`` column of
            ``image_table``, following the order of the search hits. Rows whose
            id does not compare equal to any search id are yielded afterwards
            so that no row returned by the query is dropped. Nothing is
            yielded when the search returned no hits.
        """
        hits = (vs_output.get('result') or {}).get('data_array') or []
        image_ids = [hit[0] for hit in hits]
        if not image_ids:
            # An empty id list would render as `id in ()`, which is not a valid query.
            return

        # NOTE(review): ids are interpolated unquoted, so this presumably expects
        # numeric ids coming from the index — confirm.
        id_list = ",".join(str(image_id) for image_id in image_ids)
        rows = spark.sql(f'select id, content from {image_table.path} where id in ({id_list})').collect()

        # `in (...)` gives no ordering guarantee; group the rows by id so they can be
        # emitted in the same order as the search hits.
        contents_by_id = {}
        for row in rows:
            contents_by_id.setdefault(row['id'], []).append(row['content'])

        for image_id in image_ids:
            # pop() so an id repeated in the hits does not emit its rows twice.
            for content in contents_by_id.pop(image_id, ()):
                yield Image.open(BytesIO(content))

        # Fallback for ids that did not match a search hit by equality (e.g. a type mismatch).
        for leftover in contents_by_id.values():
            for content in leftover:
                yield Image.open(BytesIO(content))

    @mlflow.trace(name="quickstart-agent")
    def predict(self, context, text_query):
        """Return a generator of images matching ``text_query``.

        Args:
            context: Unused by this implementation.
            text_query: Free-text description of the images to retrieve.

        Returns:
            The generator produced by ``_get_images``; the image lookup runs
            lazily when the generator is first advanced.
        """
        vs_output = self._vector_search_retrieval(query=text_query)
        cats = self._get_images(vs_output)
        return cats

In [0]:
# Instantiate the agent locally and load the CLIP model/processor by hand;
# load_context does not use its argument, so None is passed.
agent=CAT_AGENT()
agent.load_context(None)

In [0]:
# Smoke test: predict returns a generator of PIL images for the text query.
response=agent.predict(text_query='black cat', context=None)

In [0]:
# Display the first image yielded by the generator.
next(response)

In [0]:
# Advance the generator and display the next image.
next(response)

In [0]:
# Advance the generator once more (the search requests num_results=3).
next(response)