In [None]:
%pip install -r requirements.txt

In [None]:
%pip install databricks-sdk==0.50.0
%restart_python

In [None]:
from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig
from PIL import Image
import requests, torch
import mlflow
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import Schema, ColSpec

In [None]:
import yaml 

# Load notebook settings from the shared YAML config; any missing key yields None.
with open("../configs/config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f)

catalog_name = config.get("catalog_name")
schema_name = config.get("schema_name")
volume_name = config.get("volume_name")
volume_folder = config.get("volume_folder")
model_name = config.get("model_name")  # presumably a Hugging Face repo id like "org/model" — confirm
revision = config.get("revision")
# Unity Catalog three-level name <catalog>.<schema>.<model>, keeping only the part of the
# model id after the last "/". Single quotes inside the f-string keep this valid on
# Python < 3.12, where reusing the outer quote character is a SyntaxError.
uc_model_name = f"{catalog_name}.{schema_name}.{model_name.split('/')[-1]}"
served_model_name = config.get("served_model_name")

In [None]:
import os
# Source snapshot inside a Unity Catalog volume: /Volumes/<catalog>/<schema>/<volume>/<path...>.
# (An empty "{}" placeholder is not allowed in an f-string, so the volume name follows the schema directly.)
cache_volume = f"/Volumes/{catalog_name}/{schema_name}/{volume_name}/{revision}/{volume_folder}"
cache_hf = "/local_disk0/hf_cache"
# Local-disk copy of the snapshot that later cells load from. It must be an f-string so that
# {volume_folder} is substituted instead of being used literally.
cache_local = f"/local_disk0/{volume_folder}"

# Point the Hugging Face caches at local disk and tune hub downloads.
# NOTE(review): huggingface_hub reads most of these variables when it is first imported, and
# transformers is already imported in an earlier cell, so they may not take effect in this
# kernel session. Move them above the imports if hub downloads are ever needed here.
os.environ["HF_HOME"] = cache_hf
os.environ["HF_HUB_CACHE"] = cache_hf
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "True"
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "1000"
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # Enables optimized download backend (needs the hf_transfer package)

In [None]:
import shutil
import os

# Copy the model snapshot from the UC volume to local disk once; later cells load from
# cache_local. On failure, remove any partial copy so the next run retries instead of
# treating an incomplete directory as a valid cache, then re-raise so the notebook stops here.
if os.path.exists(cache_local):
    print("File exists locally")
else:
    print(f"Loading model from {cache_volume} to {cache_local}.")
    try:
        parent_dir = os.path.dirname(cache_local)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        shutil.copytree(cache_volume, cache_local)
    except Exception as e:
        print(f"Error: {e}")
        shutil.rmtree(cache_local, ignore_errors=True)
        raise
    print(f"Successfully loaded model from {cache_volume} to {cache_local}!")

In [None]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Sanity-check that the local snapshot loads with plain transformers before wrapping it in MLflow.
tokenizer = AutoTokenizer.from_pretrained(cache_local)

# Some causal-LM tokenizers ship without a pad token; reuse EOS in that case.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): later cells do not use this `model`, and HFModelPyfunc.load_context loads a
# second copy, so both stay resident. Consider `del model` before the pyfunc smoke test if
# GPU memory is tight.
model = AutoModelForCausalLM.from_pretrained(
    cache_local,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

In [None]:
import pandas as pd
import torch
from transformers import AutoProcessor, AutoModelForCausalLM, BitsAndBytesConfig, AutoTokenizer
import mlflow.pyfunc

# TODO: Update class name to your preferred name
class HFModelPyfunc(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc wrapper that runs chat-style generation with a Hugging Face causal LM.

    The weights directory is taken from the ``"model-weights"`` artifact. Each input row
    provides a ``user_prompt`` and, optionally, a ``system_prompt``. ``predict`` returns
    one decoded completion per row, in row order.
    """

    # Used when a row has no system_prompt column or a null value in it.
    DEFAULT_SYSTEM_PROMPT = "You are a helpful medical assistant."
    # Used when params is missing, or has no (or a null) max_tokens entry.
    DEFAULT_MAX_TOKENS = 1024

    def load_context(self, context):
        """Load the tokenizer and model from ``context.artifacts["model-weights"]``.

        Runs on CUDA in bfloat16 when a GPU is available, otherwise on CPU in float32.
        """
        self.model_id = context.artifacts["model-weights"]
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        print("************************************")
        print(f"Device: {self.device}, dtype: {self.dtype}")
        print(f"Loading model {self.model_id} to {self.device}")
        print("************************************")

        self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        # Some causal-LM tokenizers have no pad token; fall back to EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            self.model_id,
            torch_dtype=self.dtype,
        ).to(device=self.device)

    @staticmethod
    def _text_or_default(row, column, default):
        """Return ``row[column]`` as a string, or ``default`` if the column is absent or null.

        ``Series.get`` only applies its default when the key is missing. Without this check,
        a present-but-NaN/None cell would reach the chat template unchanged.
        """
        value = row.get(column)
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return default
        return str(value)

    def predict(self, model_input: pd.DataFrame, params: dict = None) -> pd.Series:
        """Generate one greedy completion per input row.

        Args:
            model_input: DataFrame with a ``user_prompt`` column and an optional
                ``system_prompt`` column.
            params: optional dict; ``max_tokens`` caps the number of new tokens
                (default ``DEFAULT_MAX_TOKENS``).

        Returns:
            pd.Series of decoded completions, one per row in input order.
        """
        max_tokens = (params or {}).get("max_tokens")
        if max_tokens is None:
            max_tokens = self.DEFAULT_MAX_TOKENS
        max_tokens = int(max_tokens)

        outputs = []
        for _, row in model_input.iterrows():
            # TODO: Update system_prompt and user_prompt
            messages = [
                {"role": "system", "content": self._text_or_default(row, "system_prompt", self.DEFAULT_SYSTEM_PROMPT)},
                {"role": "user", "content": self._text_or_default(row, "user_prompt", "")},
            ]

            encoded = self.tokenizer.apply_chat_template(
                messages,
                add_generation_prompt=True,
                tokenize=True,
                return_dict=True,
                return_tensors="pt",
            ).to(self.device)

            # Cast only floating-point tensors to the model dtype; token ids and masks stay integer.
            inputs = {k: (v.to(self.dtype) if v.is_floating_point() else v) for k, v in encoded.items()}
            prompt_len = inputs["input_ids"].size(-1)

            with torch.inference_mode():
                generation = self.model.generate(
                    **inputs,
                    max_new_tokens=max_tokens,
                    do_sample=False,
                    # Explicit pad id avoids generate()'s "setting pad_token_id" fallback warning.
                    pad_token_id=self.tokenizer.pad_token_id,
                )

            # generate() returns prompt + continuation; decode only the newly generated tokens.
            generated_tokens = generation[0][prompt_len:]
            outputs.append(self.tokenizer.decode(generated_tokens, skip_special_tokens=True))

        # object dtype keeps an empty result consistent with the string results.
        return pd.Series(outputs, dtype=object)

In [None]:
import pandas as pd
import json

# Smoke-test the pyfunc wrapper locally, before logging it to MLflow.
df = pd.DataFrame(
    [
        {
            "system_prompt": "You are a helpful medical assistant.",
            "user_prompt": "What are the symptoms of hypertension?",
        }
    ]
)


class Context:
    """Minimal stand-in for the MLflow model context: only carries an ``artifacts`` dict."""

    def __init__(self, artifacts):
        self.artifacts = artifacts


custom_context = Context(artifacts={"model-weights": str(cache_local)})
custom_hf_model = HFModelPyfunc()
custom_hf_model.load_context(custom_context)

output = custom_hf_model.predict(df)
print(output[0])

%md
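### Known issue: artifact paths at serving time differ from the paths used when logging
- Fix: inside a `PythonModel`, always resolve artifact files through `context.artifacts['<artifact-key>']` (here `context.artifacts['model-weights']`). When the model is loaded or served, each key points to the copy of that file or directory stored under the logged model's root (`<model_root>`). It does not point to the original local path passed to `log_model`.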

In [None]:
import sys
import os
from mlflow.models import infer_signature
import mlflow
import pandas as pd
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import Schema, ColSpec

# Models are registered in the Unity Catalog registry ...
mlflow.set_registry_uri("databricks-uc")
# ... while runs are tracked on the Databricks workspace tracking server.
mlflow.set_tracking_uri("databricks")

In [None]:
# Infer the model signature from the smoke-test input/output. `params` declares
# `max_tokens` as an inference-time parameter with a default of 512.
# NOTE(review): MLflow presumably fills this default when callers omit params, so served calls
# would use 512 rather than the class's own 1024 — confirm this is intended.
signature = infer_signature(
    model_input=df,
    model_output=output,
    params={"max_tokens": 512},
)

# Log the pyfunc wrapper with the local snapshot as the "model-weights" artifact. At load
# time, the wrapper reads the stored copy's path back from context.artifacts.
with mlflow.start_run():
    model_info = mlflow.pyfunc.log_model(
        artifact_path="hf_model_pyfunc",
        python_model=HFModelPyfunc(),
        signature=signature,
        pip_requirements="requirements.txt",
        artifacts={"model-weights": cache_local},
        input_example=df,
    )

In [None]:

# Register the logged model in Unity Catalog. On failure, report and re-raise, so the
# notebook stops here instead of later cells failing with a NameError on `result`.
try:
    result = mlflow.register_model(
        model_uri=model_info.model_uri,
        name=uc_model_name,
    )
except Exception as e:
    print(f"Error registering model: {e}")
    raise

print(f"Registered model version: {result.version}")

In [None]:
import mlflow

# Attach the "challenger" alias to the version registered above, so it can be loaded as
# models:/<name>@challenger without hard-coding a version number.
client = mlflow.MlflowClient()
client.set_registered_model_alias(uc_model_name, "challenger", result.version)

In [None]:
import mlflow 
# URI that resolves to whichever version currently holds the "challenger" alias.
model_uri = "models:/{}@challenger".format(uc_model_name)
print(model_uri)

In [None]:
# Load the registered model back through the generic pyfunc interface.
loaded_model = mlflow.pyfunc.load_model(model_uri=model_uri)

In [None]:

import pandas as pd

# Repeat the smoke test through the registered model to confirm the round trip works.
df = pd.DataFrame(
    [
        {
            "system_prompt": "You are a helpful medical assistant.",
            "user_prompt": "What are the symptoms of hypertension?",
        }
    ]
)

outputs = loaded_model.predict(df, params={"max_tokens": 1000})
print(outputs[0])

In [None]:
import requests
from datetime import datetime, timedelta
from databricks.sdk import WorkspaceClient
from databricks.sdk.service.serving import (
        EndpointCoreConfigInput,
        ServedEntityInput,
        AutoCaptureConfigInput,
        ServingEndpointDetailed,
        ServingModelWorkloadType,
        EndpointTag
    )

import re

# Deploy the version that holds the alias set earlier in this notebook ("challenger").
# Switch this to "Champion" once a promotion step assigns that alias.
serving_alias = "challenger"
model_version = client.get_model_version_by_alias(uc_model_name, serving_alias)

# Served-entity and endpoint names may only contain letters, digits, "-" and "_".
# Derive a plain string (not a set literal) from the model id and replace anything else with "-".
served_entity_name = re.sub(r"[^A-Za-z0-9_-]", "-", model_name.split("/")[-1])
endpoint_name = f"{served_entity_name}_endpoint"
user_email = dbutils.notebook.entry_point.getDbutils().notebook().getContext().userName().get()

served_entities = [
    ServedEntityInput(
        entity_name=uc_model_name,                   # full UC name: catalog.schema.model
        entity_version=str(model_version.version),   # the version string, not the ModelVersion object
        name=served_entity_name,
        workload_type=ServingModelWorkloadType.GPU_LARGE,
        workload_size="Small",
        scale_to_zero_enabled=True,
    )
]

# Inference-table configuration. It is NOT attached below (auto_capture_config=None); pass it
# to EndpointCoreConfigInput to enable it. The prefix avoids "/" and "-", which the raw HF
# model id could introduce into a table name.
auto_capture_config = AutoCaptureConfigInput(
    catalog_name=catalog_name,
    schema_name=schema_name,
    table_name_prefix=f"{served_entity_name.replace('-', '_')}_serving",
    enabled=True,
)

w = WorkspaceClient()

endpoint_details = w.serving_endpoints.create_and_wait(
    name=endpoint_name,
    config=EndpointCoreConfigInput(
        name=endpoint_name,
        served_entities=served_entities,
        auto_capture_config=None,
    ),
    tags=[
        EndpointTag(key="application", value=served_entity_name),
        EndpointTag(key="created_by", value=user_email),
    ],
    timeout=timedelta(minutes=180),  # wait up to three hours
)