In [0]:
import mlflow
import torch
from config import DeployConfig
import pandas as pd
import requests
import json

In [0]:
# Print the MLflow version, the PyTorch version, and the CUDA version PyTorch
# was built against, one per line.
print(mlflow.__version__, torch.__version__, torch.version.cuda, sep="\n")

In [0]:
# Register a text widget holding the path of the YAML config file (defaulting to
# ./config/env_variables.yml), read its current value back, and build the
# deployment config object from that file.
dbutils.widgets.text("config_path", "./config/env_variables.yml")
config_path = dbutils.widgets.get("config_path")
cfg = DeployConfig.from_yaml(config_path)

In [0]:
# Settings used by the rest of the notebook:
#   image_table   - its `.path` is the table queried for sample images
#   dev_model     - its `.path` is the registered model name
#   endpoint_name - name of the model serving endpoint
image_table = cfg.image_table
dev_model = cfg.dev_model
endpoint_name = cfg.endpoint_name

# PYFUNC MODELS

In [0]:
# JUST IMAGE EMBEDDING
class CLIP_IMAGE_EMBEDDING(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc model that embeds base64-encoded images with CLIP.

    ``predict`` reads the ``model_input`` column of the incoming DataFrame,
    treats every value as a base64-encoded image, and returns one embedding
    (a list of floats) per row.
    """

    def load_context(self, context):
        """Load the pretrained CLIP model and its processor.

        Args:
            context: MLflow model context. Unused: the weights are fetched by
                checkpoint name via ``from_pretrained``.
        """
        from transformers import CLIPProcessor, CLIPModel

        checkpoint = "openai/clip-vit-large-patch14"
        self.model = CLIPModel.from_pretrained(checkpoint)
        self.processor = CLIPProcessor.from_pretrained(checkpoint)

    def _get_image_embedding_bytearray(self, base64_image_str):
        """Return the CLIP image embedding for one base64-encoded image.

        Args:
            base64_image_str: Base64 encoding of the raw image file bytes.

        Returns:
            The image features of that single image as a list of floats.
        """
        # Imported locally so the method stays self-contained when the class is
        # serialized and loaded outside this notebook.
        import base64
        from io import BytesIO

        import torch
        from PIL import Image

        image_bytes = base64.b64decode(base64_image_str)
        with Image.open(BytesIO(image_bytes)) as image:
            inputs = self.processor(images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)
        # The batch holds exactly one image; return its feature vector.
        return image_features[0].numpy().tolist()

    def predict(self, context, df):
        """Embed every image in ``df['model_input']``.

        Args:
            context: MLflow model context (unused).
            df: pandas DataFrame with a ``model_input`` column of base64 strings.

        Returns:
            pandas Series holding one embedding list per input row.
        """
        return df['model_input'].apply(self._get_image_embedding_bytearray)

In [0]:
# IMAGE AND TEXT EMBEDDING MODEL
class CLIP_IMAGE_TEXT_EMBEDDING(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc model that embeds either images or text with CLIP.

    The ``input_type`` entry of ``params`` ('image' or 'text', case-insensitive)
    selects how the values of the ``model_input`` column are interpreted:
    base64-encoded images or plain text strings.
    """

    def load_context(self, context):
        """Load the pretrained CLIP model and its processor.

        Args:
            context: MLflow model context. Unused: the weights are fetched by
                checkpoint name via ``from_pretrained``.
        """
        from transformers import CLIPProcessor, CLIPModel

        checkpoint = "openai/clip-vit-large-patch14"
        self.model = CLIPModel.from_pretrained(checkpoint)
        self.processor = CLIPProcessor.from_pretrained(checkpoint)

    def _get_image_embedding_bytearray(self, base64_image_str):
        """Return the CLIP image embedding for one base64-encoded image.

        Args:
            base64_image_str: Base64 encoding of the raw image file bytes.

        Returns:
            The image features of that single image as a list of floats.
        """
        # Imported locally so the method stays self-contained when the class is
        # serialized and loaded outside this notebook.
        import base64
        from io import BytesIO

        import torch
        from PIL import Image

        image_bytes = base64.b64decode(base64_image_str)
        with Image.open(BytesIO(image_bytes)) as image:
            inputs = self.processor(images=image, return_tensors="pt")
        # Inference only: skip autograd bookkeeping for the forward pass.
        with torch.no_grad():
            image_features = self.model.get_image_features(**inputs)
        return image_features[0].numpy().tolist()

    def _get_text_embedding(self, text):
        """Return the CLIP text embedding for one text input.

        Args:
            text: The text to embed.

        Returns:
            The text features of the first sequence as a list of floats.
        """
        import torch

        inputs = self.processor(text=text, return_tensors="pt", padding=True)
        with torch.no_grad():
            text_features = self.model.get_text_features(**inputs)
        return text_features[0].numpy().tolist()

    def predict(self, context, df, params=None):
        """Embed every value of ``df['model_input']`` as image or text.

        Args:
            context: MLflow model context (unused).
            df: pandas DataFrame with a ``model_input`` column.
            params: Dict whose ``input_type`` entry is 'image' or 'text'.

        Returns:
            pandas Series holding one embedding list per input row.

        Raises:
            ValueError: If ``input_type`` is missing or is neither 'image' nor
                'text'. Failing loudly avoids returning ``None`` to the caller.
        """
        input_type = (params or {}).get('input_type')
        if not isinstance(input_type, str):
            raise ValueError(
                "params['input_type'] is required and must be 'image' or 'text'"
            )

        normalized_type = input_type.lower()
        if normalized_type == 'image':
            print('embedding image')
            return df['model_input'].apply(self._get_image_embedding_bytearray)
        if normalized_type == 'text':
            print('embedding text')
            return df['model_input'].apply(self._get_text_embedding)
        raise ValueError(
            f"Unsupported input_type {input_type!r}; expected 'image' or 'text'"
        )


# TESTING PYFUNC

In [0]:
# Fetch two rows of the `model_input` column as a pandas DataFrame; this is the
# sample fed to the pyfunc models in the local tests.
image_test_pd = (
    spark.sql(f'select model_input from {image_table.path} limit 2')
    .toPandas()
)
image_test_pd

In [0]:
# Exercise the image-only model locally by calling the pyfunc hooks directly.
# No MLflow context exists outside a logged model, so None is passed for it.
clip = CLIP_IMAGE_EMBEDDING()
clip.load_context(None)
test_result = clip.predict(None, image_test_pd)
test_result

In [0]:

# Exercise the image+text model locally on the image sample. `clip` and
# `test_result` are rebound here to the second model and its output.
clip = CLIP_IMAGE_TEXT_EMBEDDING()
clip.load_context(None)
test_result = clip.predict(None, image_test_pd, params={'input_type': 'image'})
test_result

In [0]:
# Run the text path of the currently loaded `clip` model on two identical strings.
df_text_input = pd.DataFrame({'model_input': ['hello world'] * 2})
test_result = clip.predict(None, df_text_input, params={'input_type': 'text'})
test_result

In [0]:
from io import BytesIO

from PIL import Image

# Render one stored image: take the `content` value of the first row and open
# it with PIL from an in-memory buffer.
test_image = (
    spark.sql(f'select content from {image_table.path} limit 1')
    .collect()[0]['content']
)
Image.open(BytesIO(test_image))

# LOG MODEL


In [0]:
# Pinned environment passed to `log_model` for both registered models.
pip_requirements = [
    # Extra wheel index hosting the +cu121 PyTorch builds pinned below.
    "--extra-index-url https://download.pytorch.org/whl/cu121",
    "mlflow==2.15.1",
    "setuptools<70.0.0",
    "torch==2.3.1+cu121",
    "accelerate==0.31.0",
    "astunparse==1.6.3",
    "bcrypt==3.2.0",
    "boto3==1.34.39",
    "configparser==5.2.0",
    "defusedxml==0.7.1",
    "dill==0.3.6",
    "google-cloud-storage==2.10.0",
    "ipython==8.15.0",
    "lz4==4.3.2",
    "nvidia-ml-py==12.555.43",
    "optree==0.12.1",
    "pandas==1.5.3",
    "pyopenssl==23.2.0",
    "pytesseract==0.3.10",
    "scikit-learn==1.3.0",
    "sentencepiece==0.1.99",
    "torchvision==0.18.1+cu121",
    "transformers==4.41.2",
    # Prebuilt flash-attn wheel; the filename targets cu12 / torch2.3 / cp311 / linux_x86_64.
    "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.3cxx11abiFALSE-cp311-cp311-linux_x86_64.whl",
]

In [0]:
from mlflow.models.signature import ModelSignature, infer_signature

# Infer the model signatures from the sample input and one sample embedding.
# NOTE: `test_result` is whatever the most recently executed test cell produced.
# signature1: image-only model (no params).
signature1 = infer_signature(
    model_input=image_test_pd,
    model_output=[test_result[0]],
)
# signature2: image+text model; declares the `input_type` param with 'text'
# as the example value.
signature2 = infer_signature(
    model_input=image_test_pd,
    model_output=[test_result[0]],
    params={'input_type': 'text'},
)

In [0]:
# Select the MLflow experiment that the logging runs are recorded under.
# NOTE(review): the path is tied to one user's workspace folder; consider
# sourcing it from the config instead.
mlflow.set_experiment(
    experiment_name="/Users/sara.hovakeemian@databricks.com/FE-IP/experiments/clip"
)

In [0]:
# Log the image-only CLIP pyfunc in its own run and register it under the
# configured model name.
with mlflow.start_run(run_name='clip_image_only') as run:
    mlflow.pyfunc.log_model(
        artifact_path="clip_image_only",
        python_model=CLIP_IMAGE_EMBEDDING(),
        signature=signature1,
        pip_requirements=pip_requirements,
        registered_model_name=f'{dev_model.path}',
    )

In [0]:
# Client used for the registry calls (aliases and version descriptions).
client = mlflow.MlflowClient()

In [0]:
# Mark version 1 as `champion` and describe it.
# NOTE(review): the version number is hardcoded; this assumes the image-only
# model logged by this notebook is version 1 of the registered model.
client.set_registered_model_alias(
    name=f'{dev_model.path}',
    alias='champion',
    version=1,
)

client.update_model_version(
    name=f'{dev_model.path}',
    version=1,
    description="Only does image embeddings using CLIP",
)

In [0]:
# Log the image+text CLIP pyfunc in its own run and register it under the same
# configured model name.
with mlflow.start_run(run_name='clip_image_text') as run:
    mlflow.pyfunc.log_model(
        artifact_path="clip_image_text",
        python_model=CLIP_IMAGE_TEXT_EMBEDDING(),
        signature=signature2,
        pip_requirements=pip_requirements,
        registered_model_name=f'{dev_model.path}',
    )

In [0]:
# Mark version 2 as `challenger` and describe it.
# NOTE(review): the version number is hardcoded; this assumes the image+text
# model logged by this notebook is version 2 of the registered model.
client.set_registered_model_alias(
    name=f'{dev_model.path}',
    alias='challenger',
    version=2,
)

client.update_model_version(
    name=f'{dev_model.path}',
    version=2,
    description="Can do image and text embeddings using CLIP",
)

# SERVE MODEL

In [0]:
# Create a model serving endpoint for version 1 of the registered model by
# calling the serving-endpoints REST API with the notebook's own API token.
notebook_token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

# Define the endpoint URL and headers
# NOTE(review): the workspace host is hardcoded; confirm it matches the
# workspace this notebook runs in.
url = "https://e2-demo-field-eng.cloud.databricks.com/api/2.0/serving-endpoints"
headers = {
    "Authorization": f"Bearer {notebook_token}",
    "Content-Type": "application/json"
}

# Define the payload for creating the model serving endpoint
payload = {
    "name": endpoint_name,
    "config": {
        "served_entities": [
            {
                "entity_name": f"{dev_model.path}",
                "entity_version": 1,
                "workload_size": "Medium",
                "scale_to_zero_enabled": True,
                "workload_type": "GPU_SMALL"
            }
        ]
    }
}

# Make the POST request to create the serving endpoint. `json=` lets requests
# serialize the payload, and the timeout (seconds) keeps the cell from hanging
# forever if the API never answers.
response = requests.post(url, headers=headers, json=payload, timeout=60)

# Check the response status
if response.status_code == 200:
    print("Model serving endpoint created successfully.")
else:
    print(f"Failed to create model serving endpoint: {response.text}")

# QUERY ENDPOINT

In [0]:
# Display the table path that the sample-image queries read from.
image_table.path

In [0]:
# Send one base64-encoded image from the table to the serving endpoint and
# display the JSON it returns.
image_test_pd = spark.sql(f'select model_input from {image_table.path} limit 1').toPandas()
image_base_64 = image_test_pd['model_input'].iloc[0]


# Define the model serving endpoint URL
# NOTE(review): the workspace host is hardcoded; confirm it matches the
# workspace this notebook runs in.
endpoint_url = f"https://e2-demo-field-eng.cloud.databricks.com/serving-endpoints/{endpoint_name}/invocations"

input_data = {
  "inputs" : [image_base_64]
  # ,"params" : {'input_type':'image'} #use if using the model that can produce text and image embeddings
}

notebook_token = dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().get()

# Set the headers for the request
headers = {
    "Content-Type": "application/json",
    "Authorization": f'Bearer {notebook_token}'
}

# Make the request to the model serving endpoint. The timeout (seconds) is
# generous because the endpoint is created with scale-to-zero enabled and may
# need time to start, but it still stops the cell from hanging indefinitely.
response = requests.post(endpoint_url, headers=headers, json=input_data, timeout=300)

# Parse the response
response_data = response.json()

# Display the response data
display(response_data)