In [None]:
!pip install transformers==4.44.2 mlflow==2.20.1 torch==2.3.1
%restart_python

In [None]:
# Model id to load, supplied by the caller through the notebook widget of the same name.
hugginface_model_id = dbutils.widgets.get("hugginface_model_id")

# Keep only the text after the final "/" of the id; this is the name the model
# is registered under further down.
uc_model_name = hugginface_model_id.rsplit("/", 1)[-1]

# Echo both values so they show up in the run output.
print(hugginface_model_id, uc_model_name)

In [None]:
# MLflow task type: passed to `log_model` as `task` and also recorded in the
# logged model's metadata.
task: str = "llm/v1/chat"

In [None]:
import mlflow
import transformers

# Load the causal-LM weights and the matching tokenizer for the requested model id.
model = transformers.AutoModelForCausalLM.from_pretrained(hugginface_model_id)
tokenizer = transformers.AutoTokenizer.from_pretrained(hugginface_model_id)

# Components handed to the MLflow transformers flavor.
transformers_model = dict(model=model, tokenizer=tokenizer)

# Metadata stored alongside the logged model.
# NOTE(review): everything except "task" is hard-coded to a specific Llama model,
# while the weights above come from `hugginface_model_id` — confirm these values
# are still correct when the widget points at a different model.
model_metadata = {
    "task": task,
    "pretrained_model_name": "meta-llama/Llama-3.3-8B-Instruct",
    "databricks_model_family": "LlamaForCausalLM",
    "databricks_model_size_parameters": "8b",
}

# Log the model inside a fresh run and register it in the same call;
# `model_info` is read by the next cell to publish the registered version.
with mlflow.start_run():
    model_info = mlflow.transformers.log_model(
        transformers_model=transformers_model,
        artifact_path="model",
        task=task,
        registered_model_name=uc_model_name,
        metadata=model_metadata,
        # Allow a generous wait for the registered version to be created.
        await_registration_for=1500,
    )

In [None]:
# Publish the registration results as job task values so that other tasks in
# the same job run can read them.
dbutils.jobs.taskValues.set(
    key="model_version",
    value=model_info.registered_model_version,
)
dbutils.jobs.taskValues.set(
    key="uc_model_name",
    value=uc_model_name,
)