In [0]:
%pip install -e ..
%restart_python

In [0]:
from pathlib import Path
import sys
# Make "<cwd>/../src" importable so the project package can be imported
# from this notebook.
# Guarded so that re-running the cell does not keep appending duplicate
# entries to sys.path.
_src_dir = str(Path.cwd().parent / 'src')
if _src_dir not in sys.path:
    sys.path.append(_src_dir)

In [0]:

import os
import mlflow
import yaml
from loguru import logger
from pyspark.sql import SparkSession
from pyspark.dbutils import DBUtils

from honeywell.config import GitTagsFromWidgets, ProjectConfig, Tags
from honeywell.models.basic_model import BasicModel
from pyspark.sql import SparkSession

from dotenv import load_dotenv

# Obtain a Spark session via the builder (getOrCreate reuses an existing
# session when there is one).
spark = SparkSession.builder.getOrCreate()
# DBUtils handle built on that session; later cells use it for widgets and
# for job task values.
dbutils = DBUtils(spark)



In [0]:
# COMMAND ----------
# Set up Databricks or local MLflow tracking
def is_databricks():
    """Tell whether the DATABRICKS_RUNTIME_VERSION environment variable is set.

    Returns:
        bool: True if the variable exists in the process environment
        (even with an empty value), False otherwise.
    """
    runtime_version = os.environ.get("DATABRICKS_RUNTIME_VERSION")
    return runtime_version is not None



In [0]:
# COMMAND ----------
# If you have a DEFAULT profile and are logged in with it,
# skip these lines.

# Only when not running on a Databricks runtime: point MLflow at a Databricks
# workspace through a named profile.
if not is_databricks():
    # Load variables from a .env file into the environment (python-dotenv).
    load_dotenv()
    # PROFILE must be defined; a missing variable raises KeyError here.
    profile = os.environ["PROFILE"]
    # Tracking goes to "databricks://<profile>", the registry to
    # "databricks-uc://<profile>".
    mlflow.set_tracking_uri(f"databricks://{profile}")
    mlflow.set_registry_uri(f"databricks-uc://{profile}")



In [0]:
# COMMAND ----------
# Expose an "env" text widget (default "dev") and read the selected value.
dbutils.widgets.text("env", "dev")
working_env = dbutils.widgets.get("env")
print(f"Working environment: {working_env}")

# Load the project configuration for the selected environment.
config = ProjectConfig.from_yaml(config_path="../project_config_honeywell.yml", env=working_env)
logger.info("Configuration loaded:")
logger.info(yaml.dump(config, default_flow_style=False))

# Read the git metadata from widgets and actually log the values, so the
# "Git tags loaded:" header is followed by its content.
git_tags = GitTagsFromWidgets.from_widgets(dbutils)
logger.info("Git tags loaded:")
logger.info("git_sha={}, branch={}", git_tags.git_sha, git_tags.branch)
tags = Tags(git_sha=git_tags.git_sha, branch=git_tags.branch)

# COMMAND ----------
# Initialize the model with the loaded config, the tags and the Spark session
basic_model = BasicModel(config=config, tags=tags, spark=spark)

In [0]:
# COMMAND ----------
# Run the model's data loading and feature preparation steps, in that order.
basic_model.load_data()
basic_model.prepare_features()

# COMMAND ----------
# Run the training step.
basic_model.train()

# COMMAND ----------
# log_model() returns three values; they are published as job task values
# further down in this notebook.
run_id, model_uri, model_name = basic_model.log_model()


In [0]:


# COMMAND ----------
# Register the model; the returned value is later published (as a string)
# under the "latest_version" task value.
latest_version = basic_model.register_model()

In [0]:
# COMMAND ----------
# Inside train_model script: publish this run's outputs as job task values.
# Keys are written in the same order as listed here.
task_values = {
    "model_uri": model_uri,
    "candidate_run_id": run_id,
    "model_name": model_name,
    "latest_version": str(latest_version),  # stored as a string
}
for task_key, task_value in task_values.items():
    dbutils.jobs.taskValues.set(key=task_key, value=task_value)