In [0]:
# Install/upgrade the libraries this notebook needs, then restart Python so they are importable.
# NOTE(review): only python-snappy and torch/torchvision are pinned; the `-U` (upgrade) installs
# and einops pull whatever is latest, so a re-run may resolve different versions — consider pinning.
%pip install -U databricks-sdk
%pip install -U sentence-transformers
%pip install -U mlflow
%pip install python-snappy==0.7.3
%pip install einops
# NOTE(review): this runs after sentence-transformers was installed, so it may replace the torch
# version pulled in above — confirm the pinned torch/torchvision pair is compatible.
%pip install torch==2.4.0 torchvision==0.19.0
# Restart the Python process so the freshly installed package versions are picked up.
dbutils.library.restartPython()

In [0]:
# Unity Catalog catalog and schema that hold this notebook's registered model and volume.
catalog, schema = "tax_dev", "databricks_team"

# Make them the session defaults so unqualified names resolve against them.
spark.sql(f"USE CATALOG {catalog}")
spark.sql(f"USE SCHEMA {schema}")

In [0]:
from mlflow.tracking import MlflowClient

# Registered-model name of the fine-tuned embedding model with its linear adapter,
# fully qualified as <catalog>.<schema>.<model>.
ft_adapter_model = "snowflake-arctic-embed-m-long-linear-adapter"
model_uc_path = ".".join([catalog, schema, ft_adapter_model])


def get_latest_model_version(model_name, client=None):
    """Return the highest registered version number of a model in the MLflow registry.

    Args:
        model_name: Fully qualified registered-model name, e.g. ``"catalog.schema.model"``.
        client: Optional MLflow client used for the search. When omitted, a new
            ``MlflowClient()`` is created (the original behavior). Passing one allows a
            pre-configured client to be reused.

    Returns:
        The largest version as an ``int``. Versions are compared numerically, so
        ``"10"`` ranks above ``"9"``.

    Raises:
        ValueError: if no versions are registered under ``model_name``.
    """
    if client is None:
        client = MlflowClient()
    model_version_infos = client.search_model_versions(f"name = '{model_name}'")
    versions = [int(info.version) for info in model_version_infos]
    if not versions:
        # max() on an empty list fails with an opaque "empty sequence" message; name the model.
        raise ValueError(f"No registered versions found for model '{model_name}'.")
    return max(versions)

# Resolve the newest registered version of the model named in `model_uc_path`.
# NOTE: if the instructor needs to update the model, the schema must be switched to
# SHARED_SCHEMA (not defined in this notebook) before running this cell.
latest_model_version = get_latest_model_version(model_name=model_uc_path)

In [0]:
# MLflow model URI pinned to the explicit latest version found above; displayed for inspection.
model_uri = f"models:/{model_uc_path}/{latest_model_version}"
model_uri

In [0]:
import sentence_transformers

In [0]:
from mlflow.utils.model_utils import (
    _add_code_from_conf_to_system_path,
    _download_artifact_from_uri,
    _get_flavor_configuration_from_uri,
    _validate_and_copy_code_paths,
    _validate_and_prepare_target_save_path,
)

In [0]:
# Download the model artifacts into the `test` volume of the configured catalog/schema.
# The path is derived from `catalog`/`schema` instead of repeating "tax_dev/databricks_team",
# so changing the config cell moves the download location too.
# NOTE(review): assumes a volume named `test` already exists in that schema.
# NOTE(review): `_download_artifact_from_uri` is a private MLflow helper and may change between
# releases; `mlflow.artifacts.download_artifacts` is the public counterpart.
dst_path = f"/Volumes/{catalog}/{schema}/test"
local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)

In [0]:
import logging
FLAVOR_NAME = "sentence_transformers"
_logger = logging.getLogger(__name__)

# Look up the `sentence_transformers` flavor configuration recorded for the model at `model_uri`.
flavor_config = _get_flavor_configuration_from_uri(model_uri, FLAVOR_NAME, _logger)

In [0]:
# Display the flavor configuration fetched above (it is not used by the later cells shown here).
flavor_config

In [0]:
import pathlib
print(local_model_path)

# Sub-directory of the downloaded artifact that the SentenceTransformer model is loaded from below.
SENTENCE_TRANSFORMERS_DATA_PATH = "model.sentence_transformer"
local_model_dir = pathlib.Path(local_model_path) / SENTENCE_TRANSFORMERS_DATA_PATH
local_model_dir

In [0]:
# Source of the custom adapter module that must sit next to the downloaded model files
# (presumably imported when the model is loaded with trust_remote_code — confirm).
# TODO: this is a personal workspace path; move it to the config cell (or derive it from the
# current user) so the notebook runs for other users.
adapters_src = "file:/Workspace/Users/q.yu@databricks.com/embedding_model_adapters/custom_embedding_adapter/sbert-adapter/notebooks/adapters.py"
# Copy to an explicit file path inside the model directory rather than relying on how
# dbutils.fs.cp treats a destination that is an existing directory.
dbutils.fs.cp(adapters_src, str(local_model_dir / "adapters.py"))

In [0]:
from packaging.version import Version
def _get_load_kwargs():
    """Build the keyword arguments used when loading a SentenceTransformer model.

    Returns:
        ``{"trust_remote_code": True}`` when the installed Sentence Transformers is 2.3.0 or
        newer (the release that introduced ``trust_remote_code``), otherwise an empty dict.
    """
    import sentence_transformers

    installed = Version(sentence_transformers.__version__)
    if installed < Version("2.3.0"):
        return {}
    # Remote code is always trusted: the entire repository files are saved inside the model
    # artifacts, so code only becomes untrusted if the logged artifact itself is tampered with.
    # That is a broader security concern which trust_remote_code=False could not prevent anyway.
    return {"trust_remote_code": True}

In [0]:
# Version-dependent SentenceTransformer load arguments ({"trust_remote_code": True} on
# Sentence Transformers >= 2.3.0, otherwise {}); displayed for inspection.
load_kwargs = _get_load_kwargs()
load_kwargs

In [0]:
from sentence_transformers import SentenceTransformer

In [0]:
# Display the model directory as the plain string passed to SentenceTransformer.
str(local_model_dir)

In [0]:
# Load the model from the local directory that now also contains adapters.py.
# Reuse `load_kwargs` from `_get_load_kwargs()` instead of hard-coding trust_remote_code=True,
# so the argument is only passed on Sentence Transformers versions that accept it.
embedding_model = SentenceTransformer(model_name_or_path=str(local_model_dir), **load_kwargs)
embedding_model

In [0]:
from transformers.dynamic_module_utils import get_class_from_dynamic_module