In [1]:
%load_ext autoreload
%autoreload 2

In [20]:
# add kernel pipenv run python -m ipykernel install --user --name

In [13]:
import boto3
import io
import json
import mlflow
import mlflow.xgboost
import os
import pandas as pd
import xgboost as xgb

from datetime import datetime
from hyperopt import fmin, tpe, hp, Trials, STATUS_OK
from mlflow import log_metric, log_param, log_artifact
from mlflow.entities import ViewType
from mlflow.tracking import MlflowClient
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from typing import Dict

In [3]:
DATA_PATH = "../data/student_performance_data.csv"
TARGET_COLUMN = "GPA" 
STRATIFY_COLUMN = "GradeClass"
SEED = 42

In [4]:
with open("../terraform_env.json", "r") as f:
    config = json.load(f)

In [5]:
MLFLOW_TRACKING_URI = f"postgresql://{config.get('RDS_USERNAME')}:{config.get('RDS_PASSWORD')}@{config.get('RDS_ENDPOINT')}/{config.get('RDS_DB_NAME')}"
MLFLOW_S3_BUCKET = config.get('SAVE_BUCKET')

In [60]:
os.environ["db_uri"] = MLFLOW_TRACKING_URI

In [64]:
! pipenv run mlflow db upgrade $db_uri

2024/08/04 00:07:04 INFO mlflow.store.db.utils: Updating database tables
INFO  [alembic.runtime.migration] Context impl PostgresqlImpl.
INFO  [alembic.runtime.migration] Will assume transactional DDL.
INFO  [alembic.runtime.migration] Running upgrade 5b0e9adcef9c -> 4465047574b1, increase max dataset schema size


In [7]:
MLFLOW_TRACKING_URI, MLFLOW_S3_BUCKET

('postgresql://devadmin:devpassword123@terraform-20240803194049581800000002.cfmikes0cgsk.eu-west-1.rds.amazonaws.com:5432/devdb',
 'dev-student-performance-model-storage-bucket')

In [8]:
# Set the tracking URI to the RDS PostgreSQL instance created with terraform
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

# Set the artifact location to the S3 bucket created with terraform
artifact_location = f"s3://{MLFLOW_S3_BUCKET}/mlflow-artifacts"

In [52]:
client = MlflowClient()
experiment_name = "xgboost_hyperoptimization"
mlflow.create_experiment(experiment_name, artifact_location=artifact_location)

'5'

In [16]:
data = pd.read_csv(DATA_PATH)

In [17]:
data.head()

Unnamed: 0,StudentID,Age,Gender,Ethnicity,ParentalEducation,StudyTimeWeekly,Absences,Tutoring,ParentalSupport,Extracurricular,Sports,Music,Volunteering,GPA,GradeClass
0,1001,17,1,0,2,19.833723,7,1,2,0,0,1,0,2.929196,2.0
1,1002,18,0,0,1,15.408756,0,0,1,0,0,0,0,3.042915,1.0
2,1003,15,0,2,3,4.21057,26,0,2,0,0,0,0,0.112602,4.0
3,1004,17,1,0,3,10.028829,14,0,3,1,0,0,0,2.054218,3.0
4,1005,17,1,0,2,4.672495,17,1,3,0,0,0,0,1.288061,4.0


In [35]:
TARGET_COLUMNS = ["GPA", "GradeClass"]
x = data.drop(TARGET_COLUMNS, axis=1)
y = data["GPA"]
stratify_col = data["GradeClass"]

In [39]:
x_train, x_val, y_train, y_val = train_test_split(
    x, y, test_size=0.2, stratify=stratify_col, random_state=SEED
)

In [56]:
# Define the hyperparameter space
search_space = {
    'max_depth': hp.quniform('max_depth', 3, 10, 1),
    'learning_rate': hp.loguniform('learning_rate', -5, 0),
    'n_estimators': hp.quniform('n_estimators', 100, 500, 50),
    'subsample': hp.uniform('subsample', 0.7, 1.0),
    'colsample_bytree': hp.uniform('colsample_bytree', 0.7, 1.0)
}

In [47]:
train_mtx = xgb.DMatrix(x_train, label=y_train)
valid_mtx = xgb.DMatrix(x_val, label=y_val)

In [20]:
def objective(params: Dict, train_mtx: xgb.DMatrix, valid_mtx: xgb.DMatrix) -> Dict:
    """
    Objective function for training an XGBoost model using specified hyperparameters.

    This function trains an XGBoost model with the provided hyperparameters, logs the model and its performance metrics to MLflow, 
    and returns the loss and status for hyperparameter optimization.

    Parameters:
    params (Dict): A dictionary of hyperparameters for training the XGBoost model. Expected keys include 'max_depth' and 'n_estimators'.
    train_mtx (xgb.DMatrix): Training data matrix in XGBoost DMatrix format.
    valid_mtx (xgb.DMatrix): Validation data matrix in XGBoost DMatrix format.

    Returns:
    Dict: A dictionary containing the loss (root mean squared error) and status for hyperparameter optimization.
    """
    
    with mlflow.start_run():
        mlflow.set_tag("model", "xgboost")

        params['max_depth'] = int(params['max_depth'])
        params['n_estimators'] = int(params['n_estimators'])

        xg_booster = xgb.train(
            params=params,
            dtrain=train_mtx,
            num_boost_round=1000,
            evals=[(valid_mtx, 'validation')],
            early_stopping_rounds=50
        )
        
        # Predict and evaluate the model
        y_pred = xg_booster.predict(valid_mtx)
        rmse = mean_squared_error(y_val, y_pred, squared=False)
        
        # Log parameters and metrics
        mlflow.log_params(params)
        mlflow.log_metric("rmse", rmse)

        return {'loss': rmse, 'status': STATUS_OK}

In [66]:
# Initialize MLflow
current_experiment = mlflow.set_experiment(experiment_name)
experiment_id = current_experiment.experiment_id

<Experiment: artifact_location='s3://dev-student-performance-model-storage-bucket/mlflow-artifacts', creation_time=1722722303726, experiment_id='5', last_update_time=1722722303726, lifecycle_stage='active', name='xgboost_hyperoptimization', tags={}>

In [69]:
best_result = fmin(
    fn=lambda params: objective(params, train_mtx, valid_mtx),
    space=search_space,
    algo=tpe.suggest,
    max_evals=50,
    trials=Trials()
)

[0]	validation-rmse:0.68720                                                                                                                                                                                           
[1]	validation-rmse:0.68431                                                                                                                                                                                           
[2]	validation-rmse:0.54683                                                                                                                                                                                           
[3]	validation-rmse:0.53516                                                                                                                                                                                           
[4]	validation-rmse:0.45097                                                                                                                 

Parameters: { "n_estimators" } are not used.




[23]	validation-rmse:0.29378                                                                                                                                                                                          
[24]	validation-rmse:0.29370                                                                                                                                                                                          
[25]	validation-rmse:0.29356                                                                                                                                                                                          
[26]	validation-rmse:0.29340                                                                                                                                                                                          
[27]	validation-rmse:0.29338                                                                                                                




[0]	validation-rmse:0.42388                                                                                                                                                                                           
[1]	validation-rmse:0.41027                                                                                                                                                                                           
[2]	validation-rmse:0.35114                                                                                                                                                                                           
[3]	validation-rmse:0.34964                                                                                                                                                                                           
[4]	validation-rmse:0.34506                                                                                                                 

Parameters: { "n_estimators" } are not used.




[35]	validation-rmse:0.35238                                                                                                                                                                                          
[36]	validation-rmse:0.35233                                                                                                                                                                                          
[37]	validation-rmse:0.35233                                                                                                                                                                                          
[38]	validation-rmse:0.35245                                                                                                                                                                                          
[39]	validation-rmse:0.35247                                                                                                                




[0]	validation-rmse:0.88200                                                                                                                                                                                           
[1]	validation-rmse:0.88025                                                                                                                                                                                           
[2]	validation-rmse:0.85768                                                                                                                                                                                           
[3]	validation-rmse:0.85593                                                                                                                                                                                           
[4]	validation-rmse:0.83387                                                                                                                 

Parameters: { "n_estimators" } are not used.




[74]	validation-rmse:0.29405                                                                                                                                                                                          
[75]	validation-rmse:0.29126                                                                                                                                                                                          
[76]	validation-rmse:0.28869                                                                                                                                                                                          
[77]	validation-rmse:0.28772                                                                                                                                                                                          
[78]	validation-rmse:0.28503                                                                                                                




[0]	validation-rmse:0.89195                                                                                                                                                                                           
[1]	validation-rmse:0.89107                                                                                                                                                                                           
[2]	validation-rmse:0.87792                                                                                                                                                                                           
[3]	validation-rmse:0.86506                                                                                                                                                                                           
[4]	validation-rmse:0.85241                                                                                                                 

Parameters: { "n_estimators" } are not used.




[29]	validation-rmse:0.64328                                                                                                                                                                                          
[30]	validation-rmse:0.63518                                                                                                                                                                                          
[31]	validation-rmse:0.62700                                                                                                                                                                                          
[32]	validation-rmse:0.61937                                                                                                                                                                                          
[33]	validation-rmse:0.61843                                                                                                                




[0]	validation-rmse:0.89875                                                                                                                                                                                           
[1]	validation-rmse:0.89840                                                                                                                                                                                           
[2]	validation-rmse:0.89208                                                                                                                                                                                           
[3]	validation-rmse:0.89146                                                                                                                                                                                           
[4]	validation-rmse:0.88501                                                                                                                 

Parameters: { "n_estimators" } are not used.




[47]	validation-rmse:0.70492                                                                                                                                                                                          
[48]	validation-rmse:0.70439                                                                                                                                                                                          
[49]	validation-rmse:0.69967                                                                                                                                                                                          
[50]	validation-rmse:0.69498                                                                                                                                                                                          
[51]	validation-rmse:0.69038                                                                                                                




[0]	validation-rmse:0.79530                                                                                                                                                                                           
[1]	validation-rmse:0.79010                                                                                                                                                                                           
[2]	validation-rmse:0.69348                                                                                                                                                                                           
[3]	validation-rmse:0.68421                                                                                                                                                                                           
[4]	validation-rmse:0.60366                                                                                                                 

Parameters: { "n_estimators" } are not used.




[46]	validation-rmse:0.22840                                                                                                                                                                                          
[47]	validation-rmse:0.22809                                                                                                                                                                                          
[48]	validation-rmse:0.22820                                                                                                                                                                                          
[49]	validation-rmse:0.22822                                                                                                                                                                                          
[50]	validation-rmse:0.22873                                                                                                                




[0]	validation-rmse:0.73036                                                                                                                                                                                           
[1]	validation-rmse:0.72385                                                                                                                                                                                           
[2]	validation-rmse:0.60071                                                                                                                                                                                           
[3]	validation-rmse:0.50906                                                                                                                                                                                           
[4]	validation-rmse:0.44040                                                                                                                 

Parameters: { "n_estimators" } are not used.




[30]	validation-rmse:0.26852
[31]	validation-rmse:0.26821                                                                                                                                                                                          
[32]	validation-rmse:0.26835                                                                                                                                                                                          
[33]	validation-rmse:0.26845                                                                                                                                                                                          
[34]	validation-rmse:0.26829                                                                                                                                                                                          
[35]	validation-rmse:0.26823                                                                                   




[0]	validation-rmse:0.87761                                                                                                                                                                                           
[1]	validation-rmse:0.85215                                                                                                                                                                                           
[2]	validation-rmse:0.82775                                                                                                                                                                                           
[3]	validation-rmse:0.80390                                                                                                                                                                                           
[4]	validation-rmse:0.78091                                                                                                                 

Parameters: { "n_estimators" } are not used.




[47]	validation-rmse:0.32041                                                                                                                                                                                          
[48]	validation-rmse:0.31608                                                                                                                                                                                          
[49]	validation-rmse:0.31168                                                                                                                                                                                          
[50]	validation-rmse:0.30733                                                                                                                                                                                          
[51]	validation-rmse:0.30335                                                                                                                




[0]	validation-rmse:0.89033                                                                                                                                                                                           
[1]	validation-rmse:0.88960                                                                                                                                                                                           
[2]	validation-rmse:0.87446                                                                                                                                                                                           
[3]	validation-rmse:0.87312                                                                                                                                                                                           
[4]	validation-rmse:0.85809                                                                                                                 

Parameters: { "n_estimators" } are not used.




[27]	validation-rmse:0.65323                                                                                                                                                                                          
[28]	validation-rmse:0.64317                                                                                                                                                                                          
[29]	validation-rmse:0.63377                                                                                                                                                                                          
[30]	validation-rmse:0.62441                                                                                                                                                                                          
[31]	validation-rmse:0.61487                                                                                                                




[0]	validation-rmse:0.89911                                                                                                                                                                                           
[1]	validation-rmse:0.89888                                                                                                                                                                                           
[2]	validation-rmse:0.89254                                                                                                                                                                                           
[3]	validation-rmse:0.89179                                                                                                                                                                                           
[4]	validation-rmse:0.88555                                                                                                                 

Parameters: { "n_estimators" } are not used.




[23]	validation-rmse:0.79715                                                                                                                                                                                          
[24]	validation-rmse:0.79673                                                                                                                                                                                          
[25]	validation-rmse:0.79146                                                                                                                                                                                          
[26]	validation-rmse:0.79109                                                                                                                                                                                          
[27]	validation-rmse:0.78567                                                                                                                




[0]	validation-rmse:0.79184                                                                                                                                                                                           
[1]	validation-rmse:0.78448                                                                                                                                                                                           
[2]	validation-rmse:0.68654                                                                                                                                                                                           
[3]	validation-rmse:0.67821                                                                                                                                                                                           
[4]	validation-rmse:0.59685                                                                                                                 

Parameters: { "n_estimators" } are not used.




[55]	validation-rmse:0.21803                                                                                                                                                                                          
[56]	validation-rmse:0.21819                                                                                                                                                                                          
[57]	validation-rmse:0.21803                                                                                                                                                                                          
[58]	validation-rmse:0.21868                                                                                                                                                                                          
[59]	validation-rmse:0.21855                                                                                                                




[0]	validation-rmse:0.89253                                                                                                                                                                                           
[1]	validation-rmse:0.88010                                                                                                                                                                                           
[2]	validation-rmse:0.86798                                                                                                                                                                                           
[3]	validation-rmse:0.85579                                                                                                                                                                                           
[4]	validation-rmse:0.84379                                                                                                                 

Parameters: { "n_estimators" } are not used.




[44]	validation-rmse:0.52623                                                                                                                                                                                          
[45]	validation-rmse:0.51997                                                                                                                                                                                          
[46]	validation-rmse:0.51424                                                                                                                                                                                          
[47]	validation-rmse:0.50829                                                                                                                                                                                          
[48]	validation-rmse:0.50266                                                                                                                

Parameters: { "n_estimators" } are not used.




[26]	validation-rmse:0.68160                                                                                                                                                                                          
[27]	validation-rmse:0.68041                                                                                                                                                                                          
[28]	validation-rmse:0.67114                                                                                                                                                                                          
[29]	validation-rmse:0.66230                                                                                                                                                                                          
[30]	validation-rmse:0.65345                                                                                                                




[0]	validation-rmse:0.86255                                                                                                                                                                                           
[1]	validation-rmse:0.85983                                                                                                                                                                                           
[2]	validation-rmse:0.81779                                                                                                                                                                                           
[3]	validation-rmse:0.81414                                                                                                                                                                                           
[4]	validation-rmse:0.77494                                                                                                                 

Parameters: { "n_estimators" } are not used.




[44]	validation-rmse:0.27168                                                                                                                                                                                          
[45]	validation-rmse:0.27096                                                                                                                                                                                          
[46]	validation-rmse:0.27023                                                                                                                                                                                          
[47]	validation-rmse:0.26612                                                                                                                                                                                          
[48]	validation-rmse:0.26534                                                                                                                




[0]	validation-rmse:0.73214                                                                                                                                                                                           
[1]	validation-rmse:0.72157                                                                                                                                                                                           
[2]	validation-rmse:0.58886                                                                                                                                                                                           
[3]	validation-rmse:0.57528                                                                                                                                                                                           
[4]	validation-rmse:0.47643                                                                                                                 

Parameters: { "n_estimators" } are not used.




[60]	validation-rmse:0.21214
[61]	validation-rmse:0.21257                                                                                                                                                                                          
[62]	validation-rmse:0.21286                                                                                                                                                                                          
[63]	validation-rmse:0.21288                                                                                                                                                                                          
[64]	validation-rmse:0.21341                                                                                                                                                                                          
[65]	validation-rmse:0.21345                                                                                   




[0]	validation-rmse:0.89858                                                                                                                                                                                           
[1]	validation-rmse:0.89834                                                                                                                                                                                           
[2]	validation-rmse:0.89152                                                                                                                                                                                           
[3]	validation-rmse:0.89075                                                                                                                                                                                           
[4]	validation-rmse:0.88388                                                                                                                 

Parameters: { "n_estimators" } are not used.




[26]	validation-rmse:0.78258                                                                                                                                                                                          
[27]	validation-rmse:0.78222                                                                                                                                                                                          
[28]	validation-rmse:0.77657                                                                                                                                                                                          
[29]	validation-rmse:0.77110                                                                                                                                                                                          
[30]	validation-rmse:0.76573                                                                                                                

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



[0]	validation-rmse:0.87494                                                                                                                                                                                           
[1]	validation-rmse:0.87266                                                                                                                                                                                           
[2]	validation-rmse:0.84412                                                                                                                                                                                           
[3]	validation-rmse:0.81596                                                                                                                                                                                           
[4]	validation-rmse:0.78986                                                                                                                 

Parameters: { "n_estimators" } are not used.




[52]	validation-rmse:0.30054                                                                                                                                                                                          
[53]	validation-rmse:0.29704                                                                                                                                                                                          
[54]	validation-rmse:0.29341                                                                                                                                                                                          
[55]	validation-rmse:0.29030                                                                                                                                                                                          
[56]	validation-rmse:0.28864                                                                                                                




[0]	validation-rmse:0.45203                                                                                                                                                                                           
[1]	validation-rmse:0.41833                                                                                                                                                                                           
[2]	validation-rmse:0.29391                                                                                                                                                                                           
[3]	validation-rmse:0.27371                                                                                                                                                                                           
[4]	validation-rmse:0.26496                                                                                                                 

Parameters: { "n_estimators" } are not used.




[41]	validation-rmse:0.28165                                                                                                                                                                                          
[42]	validation-rmse:0.28256                                                                                                                                                                                          
[43]	validation-rmse:0.28237                                                                                                                                                                                          
[44]	validation-rmse:0.28281                                                                                                                                                                                          
[45]	validation-rmse:0.28341                                                                                                                




[0]	validation-rmse:0.83165                                                                                                                                                                                           
[1]	validation-rmse:0.82768                                                                                                                                                                                           
[2]	validation-rmse:0.75865                                                                                                                                                                                           
[3]	validation-rmse:0.75245                                                                                                                                                                                           
[4]	validation-rmse:0.69204                                                                                                                 

Parameters: { "n_estimators" } are not used.




[51]	validation-rmse:0.21590                                                                                                                                                                                          
[52]	validation-rmse:0.21544                                                                                                                                                                                          
[53]	validation-rmse:0.21537                                                                                                                                                                                          
[54]	validation-rmse:0.21496                                                                                                                                                                                          
[55]	validation-rmse:0.21503                                                                                                                




[0]	validation-rmse:0.88535                                                                                                                                                                                           
[1]	validation-rmse:0.86667                                                                                                                                                                                           
[2]	validation-rmse:0.84835                                                                                                                                                                                           
[3]	validation-rmse:0.83024                                                                                                                                                                                           
[4]	validation-rmse:0.81294                                                                                                                 

Parameters: { "n_estimators" } are not used.




[73]	validation-rmse:0.33219                                                                                                                                                                                          
[74]	validation-rmse:0.32987                                                                                                                                                                                          
[75]	validation-rmse:0.32767                                                                                                                                                                                          
[76]	validation-rmse:0.32542                                                                                                                                                                                          
[77]	validation-rmse:0.32312                                                                                                                




[0]	validation-rmse:0.77555                                                                                                                                                                                           
[1]	validation-rmse:0.76855                                                                                                                                                                                           
[2]	validation-rmse:0.65868                                                                                                                                                                                           
[3]	validation-rmse:0.64701                                                                                                                                                                                           
[4]	validation-rmse:0.55819                                                                                                                 

Parameters: { "n_estimators" } are not used.




[42]	validation-rmse:0.22377                                                                                                                                                                                          
[43]	validation-rmse:0.22389                                                                                                                                                                                          
[44]	validation-rmse:0.22452                                                                                                                                                                                          
[45]	validation-rmse:0.22442                                                                                                                                                                                          
[46]	validation-rmse:0.22401                                                                                                                




[0]	validation-rmse:0.89884                                                                                                                                                                                           
[1]	validation-rmse:0.89833                                                                                                                                                                                           
[2]	validation-rmse:0.89218                                                                                                                                                                                           
[3]	validation-rmse:0.88608                                                                                                                                                                                           
[4]	validation-rmse:0.88004                                                                                                                 

Parameters: { "n_estimators" } are not used.




[77]	validation-rmse:0.58911                                                                                                                                                                                          
[78]	validation-rmse:0.58586                                                                                                                                                                                          
[79]	validation-rmse:0.58259                                                                                                                                                                                          
[80]	validation-rmse:0.57928                                                                                                                                                                                          
[81]	validation-rmse:0.57613                                                                                                                

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



[0]	validation-rmse:0.87817                                                                                                                                                                                           
[1]	validation-rmse:0.87637                                                                                                                                                                                           
[2]	validation-rmse:0.84999                                                                                                                                                                                           
[3]	validation-rmse:0.84757                                                                                                                                                                                           
[4]	validation-rmse:0.82203                                                                                                                 

Parameters: { "n_estimators" } are not used.




[66]	validation-rmse:0.28735                                                                                                                                                                                          
[67]	validation-rmse:0.28431                                                                                                                                                                                          
[68]	validation-rmse:0.28148                                                                                                                                                                                          
[69]	validation-rmse:0.27887                                                                                                                                                                                          
[70]	validation-rmse:0.27637                                                                                                                




100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [03:06<00:00,  3.73s/trial, best loss: 0.19897295513590274]


In [75]:
runs = client.search_runs(
    experiment_ids=[experiment_id],
    filter_string="",
    run_view_type=ViewType.ACTIVE_ONLY,
    max_results=100,
    order_by=["metrics.rmse ASC"]
)

In [81]:
best_run = runs[0]
print(f"Best run ID: {best_run.info.run_id}, RMSE: {best_run.data.metrics['rmse']}")

Best run ID: 4ded625e45f6469c9bfc8510956027a3, RMSE: 0.19897295513590274


In [82]:
best_run.data.params

{'colsample_bytree': '0.7084146376368476',
 'learning_rate': '0.02287865756458056',
 'max_depth': '3',
 'n_estimators': '400',
 'subsample': '0.8077647505739515'}

In [10]:
def load_data(data_path: str) -> [pd.DataFrame, pd.DataFrame, pd.Series, pd.Series]:
    """Load the data and split it into training & validation
    
    Parameters:
    data_path (str): The path to the data file. The file can be in CSV format.

    Raises:
    FileNotFoundError: If the file does not exist at the specified path.
    ValueError: If the file format is not supported.
    """

    try:
        if data_path.endswith('.csv'):
            data = pd.read_csv(data_path)
        else:
            raise ValueError("Unsupported file format. Please provide a CSV file.")
    except FileNotFoundError as e:
        raise FileNotFoundError(f"The file at path {data_path} was not found.") from e


    x = data.drop([TARGET_COLUMN, STRATIFY_COLUMN], axis=1)
    y = data[TARGET_COLUMN]
    stratify_data = data[STRATIFY_COLUMN]

    x_train, x_val, y_train, y_val = train_test_split(
        x, y, test_size=0.2, stratify=stratify_data, random_state=SEED
    )

    return data

In [11]:
def get_hyperparam_space() -> Dict:
    """Define the hyperparameter space
    
    Returns:
    search_space (dict): Dictionary with defined 
    hyperparameters space
    """

    search_space = {
        'max_depth': hp.quniform('max_depth', 3, 10, 1),
        'learning_rate': hp.loguniform('learning_rate', -5, 0),
        'n_estimators': hp.quniform('n_estimators', 100, 500, 50),
        'subsample': hp.uniform('subsample', 0.7, 1.0),
        'colsample_bytree': hp.uniform('colsample_bytree', 0.7, 1.0)
    }

    return search_space

In [16]:
def upload_to_s3(s3: boto3.session.Session, bucket: str, xgb_model: xgb.Booster) -> str:
    """Save the XGBooster

    Parameters:
    s3 (boto3.session.Session): Opened client session allowing to save the model in S3
    bucket (str): Bucket path
    xgb_model (xgb.Booster): Trained xgbooster

    Raises:
    FileNotFoundError: If the file does not exist at the specified path.
    NoCredentialsError: If the credentials do not allow to open the session.

    Return:
    saved_path (str): path to the model in s3
    """
    
    try:
        # Save the model to an in-memory buffer
        buffer = io.BytesIO()
        xgb_model.save_model(buffer)
        # Reset buffer position to the beginning
        buffer.seek(0)
        todays_date = datetime.today().strftime('%Y_%m_%d')
        s3_file = f"xgboost_{todays_date}.model"
        s3.upload_file(buffer, bucket, s3_file)
        saved_path = "s3://{bucket}/{s3_file}"
        print(f"Upload Successful {s3_file} to {saved_path}")
    except FileNotFoundError:
        print(f"The file {local_file} was not found")
    except NoCredentialsError:
        print("Credentials not available")

    return saved_path

In [24]:
def save_trained_model_mlflow(xgb_model: xgb.Booster, params: Dict, saved_path: str) -> [str, str]:
    """
    Logs a trained XGBoost model and its parameters to MLflow and registers it with the given path.

    This function logs the model's hyperparameters, artifacts, and S3 path to MLflow,
    and registers the model under the given saved path. It starts a new MLflow run for these operations.

    Parameters:
    xgb_model (xgb.Booster): The trained XGBoost model to be logged.
    params (Dict): A dictionary of hyperparameters used for training the model.
    saved_path (str): The S3 path where the model is stored.

    Returns:
    run_id (str): The ID of the MLflow run.
    experiment_id (str): The ID of the MLflow experiment.

    Logs:
    - Parameters used for training the model.
    - The model artifact.
    - The S3 path where the model is stored.
    - Registers the model in MLflow.
    """

    # Log the model to MLflow
    with mlflow.start_run() as run:
        mlflow.log_params(params=params)
        
        # Log the model
        mlflow.log_artifact(model_path, artifact_path="model")
    
        # Optionally log the S3 path where the model is stored
        mlflow.log_param("s3_path", saved_path)
        
        # Log the model
        mlflow.xgboost.log_model(xgb_model, "student_performance_regressor")
        
        # Register the model
        model_uri = f"runs:/{run.info.run_id}/student_performance_regressor"
        mlflow.register_model(model_uri=model_uri, name="student_performance_regressor")
    
        run_id = run.info.run_id
        experiment_id = run.info.experiment_id
        print(f"Model logged in run {run_id} of experiment {experiment_id}")

    return run_id, experiment_id

In [25]:
def move_model_staging(client: MlflowClient, model_name: str, run_id: str):
    """
    Registers the model version from a specific run and promotes it to the staging phase in MLflow.

    This function registers a model version using the provided model name and run ID, 
    and then transitions the model version to the staging phase in MLflow Model Registry.

    Parameters:
    client (MlflowClient): The MLflow client used to interact with the MLflow tracking server.
    model_name (str): The name of the model to be registered and promoted.
    run_id (str): The ID of the MLflow run from which the model version is to be registered.

    Returns:
    None

    Actions:
    - Registers the model version from the given run ID.
    - Promotes the registered model version to the "Staging" stage in MLflow Model Registry.
    """
    
    
    model_version = client.create_model_version(
        name=model_name,
        source=f"runs:/{run_id}/{model_name}",
        run_id=run_id
    )
    
    # Transition the model version to staging
    client.transition_model_version_stage(
        name=model_name,
        version=model_version.version,
        stage="Staging"
    )
    
    print(f"Model version {model_version.version} promoted to staging")

In [32]:
t = set()
t.add((1,2))
t.add("\\")
t.add("/")

In [29]:
(1, 2) in t

True

In [33]:
t

{(1, 2), '/', '\\'}