In [0]:
%pip install databricks_vectorsearch 
# Restart the Python process right after the install; presumably required so the
# newly installed package is importable in the cells below.
dbutils.library.restartPython()

In [0]:
from config import DeployConfig
import requests
import json

In [0]:
# Register a "config_path" text widget (default ./config/env_variables.yml) and
# read its current value, so the config location can be supplied from outside the notebook.
dbutils.widgets.text("config_path", "./config/env_variables.yml")
config_path = dbutils.widgets.get("config_path")
# Build the deployment configuration object from the YAML file at that path.
cfg = DeployConfig.from_yaml(config_path)

In [0]:
# Config entries used by the agent below. Each object is accessed through
# `.path` later on (and `vs_index` also through `.endpoint`).
vs_index = cfg.vs_index
image_table = cfg.image_table
user_online_table = cfg.user_online_table

In [0]:
# Show the online table's path as the cell output (quick sanity check of the config).
user_online_table.path

#AGENT BUILD

In [0]:
import mlflow
from mlflow.deployments import get_deploy_client
from databricks.vector_search.index import VectorSearchIndex
from databricks.vector_search.client import VectorSearchClient
from mlflow.entities import SpanType
from PIL import Image
from io import BytesIO
import os
import requests

In [0]:
# Copy the service principal credentials from the Databricks secret scope into
# environment variables; CAT_AGENT.__init__ reads CLIENT_ID / CLIENT_SECRET from os.environ.
os.environ['CLIENT_ID'] = dbutils.secrets.get("shovakeemian-scope", "shovakeemian-sp-client-id")
os.environ['CLIENT_SECRET'] = dbutils.secrets.get("shovakeemian-scope", "shovakeemian-sp-client-secret")

In [0]:
class CAT_AGENT(mlflow.pyfunc.PythonModel):
    """Pyfunc agent that returns cat images matching a user's preferred breed.

    ``predict`` chains three steps:
      1. look up the user's ``breed`` in the online table (REST call with an
         OAuth bearer token obtained for the service principal),
      2. embed the breed text with CLIP and run a vector-search similarity query,
      3. load the matching image rows from the image table and yield them as
         PIL images.
    """

    # Workspace hosting both the vector search client and the OAuth token endpoint.
    WORKSPACE_URL = "https://e2-demo-field-eng.cloud.databricks.com/"
    # Upper bound (seconds) for any single HTTP call so a stuck connection
    # cannot hang the agent indefinitely.
    REQUEST_TIMEOUT_S = 30

    def __init__(self):
        """Set up the HTTP session, credentials and vector search client."""
        # from databricks.sdk import WorkspaceClient
        import requests
        self.session = requests.Session()

        # self.w = WorkspaceClient()
        # self.client = get_deploy_client("databricks")

        # Bearer token for the online table; filled in by update_oauth_token().
        self.oauth = None
        # Service principal credentials; None when the variables are not set.
        self.CLIENT_ID = os.environ.get("CLIENT_ID")
        self.CLIENT_SECRET = os.environ.get("CLIENT_SECRET")

        self.vsc = VectorSearchClient(
            workspace_url=self.WORKSPACE_URL,
            service_principal_client_id=self.CLIENT_ID,
            service_principal_client_secret=self.CLIENT_SECRET
        )

        self.online_table_url = 'https://80ad8fdc-e516-47cb-b709-85397e83bbf7.online-tables.cloud.databricks.com/api/2.0/workspace/1444828305810485/online/pgrest/ml_shovakeemian/user_pet_features_online'
        self.online_table_schema = 'feip'

    def load_context(self, context):
        """Load the CLIP model and processor used to embed query text.

        Args:
            context: MLflow model context; not used by this implementation.
        """
        from transformers import CLIPProcessor, CLIPModel
        # Initialize tokenizer and model.
        # Will not be needed once ai_query can take params of a pyfunc: the
        # serving endpoint could then produce both image and text embeddings.
        self.model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")

    # NOTE(review): mlflow.trace records a function's return value as span
    # output, so the access token returned here likely ends up in the trace.
    # Consider dropping the return value or the decorator - confirm with callers.
    @mlflow.trace(name="update_oauth_token")
    def update_oauth_token(self):
        """Fetch a fresh OAuth token for the service principal and cache it.

        Returns:
            The access token string from the token endpoint response (``None``
            if the response has no ``access_token`` field). Also stored on
            ``self.oauth``.

        Raises:
            requests.HTTPError: if the token endpoint answers with an error status.
            requests.Timeout: if the endpoint does not answer within
                ``REQUEST_TIMEOUT_S`` seconds.
        """
        url = self.WORKSPACE_URL.rstrip('/') + '/oidc/v1/token'
        data = {
            'grant_type': 'client_credentials',
            'client_id': self.CLIENT_ID,
            'client_secret': self.CLIENT_SECRET,
            'scope': 'all-apis',
            'authorization_details': '[{"type":"unity_catalog_permission","securable_type":"table","securable_object_name":"ml_shovakeemian.feip.user_pet_features_online","operation": "ReadOnlineView"}]'
        }

        response = requests.post(url, data=data, timeout=self.REQUEST_TIMEOUT_S)
        response.raise_for_status()
        self.oauth = response.json().get('access_token')
        return self.oauth

    def _fetch_breed_rows(self, user):
        """Query the online table for ``user`` and return the decoded JSON rows.

        The Authorization header is built on every call so that it always
        carries the token currently stored in ``self.oauth``.
        """
        headers = {
            "Authorization": f"Bearer {self.oauth}",
            "Accept-Profile": self.online_table_schema
        }
        query_params = {
            "select": "breed",
            "user_id": f"in.({user})"
        }
        response = self.session.get(
            self.online_table_url,
            headers=headers,
            params=query_params,
            timeout=self.REQUEST_TIMEOUT_S,
        )
        response.raise_for_status()
        return response.json()

    @mlflow.trace(name="user_lookup", span_type=SpanType.RETRIEVER, attributes={"table": user_online_table.path})
    def _online_table_lookup(self, user):
        """Return the breed stored for ``user`` in the online table.

        Args:
            user: user id to look up.

        Returns:
            The ``breed`` value of the first returned row, with zero-width
            spaces (U+200B) removed.

        Raises:
            Whatever the second attempt raises (HTTP error, missing row, ...)
            when the lookup still fails after refreshing the token once.
        """
        if self.oauth is None:
            # No token yet: get one up front instead of sending "Bearer None".
            self.update_oauth_token()

        try:
            cat_breed = self._fetch_breed_rows(user)[0]['breed']
        except Exception:
            # Most likely an expired/rejected token. Refresh it and retry once;
            # the retry rebuilds the headers so the NEW token is actually sent.
            # A second failure propagates to the caller.
            self.update_oauth_token()
            cat_breed = self._fetch_breed_rows(user)[0]['breed']

        cat_breed = cat_breed.replace('\u200b', '')
        return cat_breed

    @mlflow.trace(name="compute_text_embedding", span_type=SpanType.EMBEDDING, attributes={"model": "clip-vit-large-patch14"})
    def _get_text_embedding(self, text):
        """Compute the CLIP text embedding for ``text``.

        Returns:
            The first row of the text-feature tensor converted to a plain list.
        """
        # Want to change this to call the serving endpoint when ai_query can
        # take params of a pyfunc.
        inputs = self.processor(text=text, return_tensors="pt", padding=True)
        text_features = self.model.get_text_features(**inputs)
        return text_features.detach().numpy().tolist()[0]

    @mlflow.trace(name="cat_image_vs_imageemb_lookup", span_type=SpanType.RETRIEVER, attributes={"model": "clip-vit-large-patch14", "vs_index": vs_index.path})
    def _vector_search_retrieval(self, query):
        """Embed ``query`` and return the 3 nearest entries of the image index.

        Returns:
            The raw ``similarity_search`` response, requesting the ``id`` and
            ``path`` columns.
        """
        text_embed_query = self._get_text_embedding(query)
        index = self.vsc.get_index(endpoint_name=vs_index.endpoint, index_name=vs_index.path)
        vs_output = index.similarity_search(columns=["id", 'path'], query_vector=text_embed_query, num_results=3)
        return vs_output

    @mlflow.trace(name="cat_id_image_lookup", span_type=SpanType.RETRIEVER, attributes={"table": image_table.path})
    def _get_images(self, vs_output):
        """Yield PIL images for the ids contained in a vector search response.

        Takes the first column of every row in
        ``vs_output['result']['data_array']`` as the image id, reads the
        ``content`` column of those rows from the image table and opens each
        one with PIL. Yields nothing when the search returned no rows.

        NOTE(review): the SQL ``IN`` filter gives no ordering guarantee, so the
        images are not necessarily yielded in similarity order.
        """
        rows = vs_output['result'].get('data_array') or []
        image_ids = [row[0] for row in rows]
        if not image_ids:
            # "where id in ()" is not valid SQL; with no hits there is nothing to load.
            return

        id_list = ",".join(str(image_id) for image_id in image_ids)
        image_lookups = spark.sql(f'select content from {image_table.path} where id in ({id_list})').collect()
        # return Image.open(BytesIO(image_lookups[0]['content']))
        for image in image_lookups:
            yield Image.open(BytesIO(image['content']))

    @mlflow.trace(name="quickstart-agent")
    def predict(self, context, user_id):
        """Return a generator of cat images for ``user_id``.

        Args:
            context: MLflow model context; not used.
            user_id: id of the user whose preferred breed drives the search.

        Returns:
            The generator produced by ``_get_images``; consume it with
            ``next()`` or a ``for`` loop.
        """
        cat_breed = self._online_table_lookup(user=user_id)
        vs_output = self._vector_search_retrieval(query=cat_breed)
        cats = self._get_images(vs_output)
        return cats

In [0]:
# Instantiate the agent locally and load the CLIP model/processor by hand;
# load_context does not use its argument, so None is passed.
agent=CAT_AGENT()
agent.load_context(None)

In [0]:
# Sample user id used by the manual checks in the following cells.
user_id='9c94dad5-7afb-4e65-b182-9b5fe3a1eff3'

In [0]:
# Fetch an OAuth token and cache it on the agent.
# NOTE(review): the method returns the token, so as the last expression of the
# cell it is presumably echoed in the cell output - clear the output before sharing.
agent.update_oauth_token()

In [0]:
# Step 1 in isolation: look up the sample user's breed in the online table.
cat_breed=agent._online_table_lookup(user=user_id)

In [0]:
# Step 2 in isolation: embed the breed text and display the raw vector search response.
agent._vector_search_retrieval(query=cat_breed)

In [0]:
# End-to-end run: predict returns the image generator built by _get_images,
# which is consumed with next() in the cells below.
response=agent.predict(user_id=user_id, context=None)

In [0]:
# Pull and display the first image from the generator.
next(response)

In [0]:
# Pull and display the second image from the generator.
next(response)

In [0]:
# Pull and display the third image; next() raises StopIteration if the
# generator has already been exhausted (fewer than three images found).
next(response)

In [0]:
# import mlflow
# from mlflow.deployments import get_deploy_client


# class QAChain(mlflow.pyfunc.PythonModel):
#     def __init__(self):
#         self.client = get_deploy_client("databricks")

#     @mlflow.trace(name="quickstart-chain")
#     def predict(self, model_input, system_prompt, params):
#         messages = [
#                 {
#                     "role": "system",
#                     "content": system_prompt,
#                 },
#                 {
#                     "role": "user",
#                     "content":  model_input[0]["query"]
#                 }
#           ]
        
#         traced_predict = mlflow.trace(self.client.predict)
#         output = traced_predict(
#             endpoint=params["model_name"],
#             inputs={
#                 "temperature": params["temperature"],
#                 "max_tokens": params["max_tokens"],
#                 "messages": messages,
#             },
#         )
        
#         with mlflow.start_span(name="_final_answer") as span:
#             span.set_inputs({"query": model_input[0]["query"]})

#             answer = output["choices"][0]["message"]["content"]

#             span.set_outputs({"generated_text": answer})     
#      # Attributes computed at runtime can be set using the set_attributes() method.
#      span.set_attributes({
#        "model_name": params["model_name"],
#                 "prompt_tokens": output["usage"]["prompt_tokens"],
#                 "completion_tokens": output["usage"]["completion_tokens"],
#                 "total_tokens": output["usage"]["total_tokens"]
#             })
#         return answer