In [0]:
import mlflow
from config import DeployConfig
from PIL import Image 
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import StringType
import base64

In [0]:
# Expose the config location as a notebook widget so it can be overridden at
# run time (e.g. as a job parameter); fall back to the repo-local YAML file.
DEFAULT_CONFIG_PATH = "./config/env_variables.yml"
dbutils.widgets.text("config_path", DEFAULT_CONFIG_PATH)

config_path = dbutils.widgets.get("config_path")
cfg = DeployConfig.from_yaml(config_path)

In [0]:
# Config sections used by this notebook:
#   image_source.path -> folder prefix the JPG files are read from
#   image_table.path  -> Delta table the prepared images are written to
# Plain attribute access: the previous getattr(cfg, f"...") used a constant
# name and an f-string with no placeholders, which added nothing.
image_source = cfg.image_source
image_table = cfg.image_table

In [0]:
def convert_to_base64(image_bytearray):
  """Encode raw image bytes as a base64 ASCII string.

  The model serving endpoint cannot take in a byte-array image, so each image
  is shipped as base64 text instead.

  Args:
    image_bytearray: bytes-like object holding the raw file content, or None
      when the Spark cell is null.

  Returns:
    The standard (padded, unchunked) base64 encoding decoded to ``str``, or
    None when the input is None. Spark hands nulls to Python UDFs as None, and
    ``base64.b64encode(None)`` would raise and fail the whole task.
  """
  if image_bytearray is None:
    return None
  return base64.b64encode(image_bytearray).decode('utf-8')

# Read every *.jpg under the configured image folder; the binaryFile source
# yields one row per matching file, and this cell relies on its `path` and
# `content` columns.
df_images = (
  spark.read.format("binaryFile")
  .option("pathGlobFilter", "*.jpg")
  .load(f"{image_source.path}CAT_00")
)

# Assign a 1-based id. Order by file path rather than by
# monotonically_increasing_id(), which is non-deterministic (its values depend
# on partition layout). The table is overwritten on every run, so a stable
# ordering keeps each image's id the same across re-runs.
# NOTE(review): a window with no partitionBy moves all rows into a single
# partition; acceptable for a modest image folder, revisit if it grows large.
window_spec = Window.orderBy(F.col("path"))
df_images = df_images.withColumn("id", F.row_number().over(window_spec))

# The serving endpoint cannot take in a byte-array image, so convert each
# image to a base64-encoded string for the model input column.
to_base64_udf = F.udf(convert_to_base64, StringType())
df_images = df_images.withColumn("model_input", to_base64_udf(F.col("content")))

In [0]:
# Preview the prepared DataFrame (including the added `id` and `model_input`
# columns) before it is persisted in the next cell.
display(df_images)

In [0]:
# Persist the prepared images as a Delta table, replacing whatever the table
# held from a previous run.
(
  df_images.write
  .mode("overwrite")
  .format("delta")
  .saveAsTable(image_table.path)
)