In [4]:
import mlflow
import mlflow.sklearn
from pycaret.classification import setup, create_model, tune_model, predict_model, get_metrics
from sklearn.metrics import log_loss, f1_score
import pandas as pd

In [7]:
# Make sure that the mlflow server is running locally in the port "5000", otherwise you need to change this url below
mlflow.set_tracking_uri('http://127.0.0.1:5000')
mlflow.set_experiment('Training Notebook')

<Experiment: artifact_location='mlflow-artifacts:/241919081044893865', creation_time=1743352035842, experiment_id='241919081044893865', last_update_time=1743352035842, lifecycle_stage='active', name='Training Notebook', tags={}>

In [10]:
# load dataset
data = pd.read_parquet('../data/05_model_input/train_dataset_kobe_dev.parquet')
target_column = 'shot_made_flag'
y = data[target_column]

In [9]:
# pycaret setup
clf_setup = setup(data, target=target_column, session_id=123, fix_imbalance=True)

Unnamed: 0,Description,Value
0,Session id,123
1,Target,shot_made_flag
2,Target type,Binary
3,Original data shape,"(16228, 7)"
4,Transformed data shape,"(16743, 7)"
5,Transformed train set shape,"(11874, 7)"
6,Transformed test set shape,"(4869, 7)"
7,Numeric features,6
8,Preprocess,True
9,Imputation type,simple


In [14]:
with mlflow.start_run():
    # train logistic regression model
    lr_model = create_model('lr')
    lr_preds = predict_model(lr_model, data=data)
    lr_log_loss = log_loss(y, lr_preds['prediction_label'])
    
    # registry metrics for logistic regression model
    mlflow.log_metric('log_loss_lr', lr_log_loss)
    mlflow.sklearn.log_model(lr_log_loss, 'logistic_regression_model')
    
    # train decision hree model
    tree_model = create_model('dt')
    dt_preds = predict_model(tree_model, data=data)
    dt_log_loss = log_loss(y, dt_preds['prediction_label'])
    dt_f1 = f1_score(y, dt_preds['prediction_label'])
    
    # registry metrics for decision tree model
    mlflow.log_metric('decision_tree_lr', lr_log_loss)
    mlflow.log_metric('decision_tree_log_loss', dt_log_loss)
    mlflow.log_metric('decision_tree_f1_score', dt_f1)
    mlflow.sklearn.log_model(tree_model, 'decision_tree_model')
    
    best_model = "logistic_model" if lr_log_loss < dt_log_loss else "tree_model"
    mlflow.log_param("best_model", best_model)
    print(f"Modelo escolhido: {best_model}")
    

Unnamed: 0_level_0,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC
Fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,0.5713,0.5892,0.5314,0.5528,0.5419,0.1393,0.1394
1,0.566,0.5904,0.5295,0.5467,0.538,0.129,0.1291
2,0.5836,0.5882,0.5351,0.5675,0.5508,0.1634,0.1637
3,0.5687,0.609,0.5498,0.5478,0.5488,0.1357,0.1357
4,0.5651,0.5773,0.5332,0.5453,0.5392,0.1276,0.1276
5,0.5836,0.6196,0.5185,0.57,0.543,0.1622,0.1628
6,0.5599,0.5856,0.5406,0.5386,0.5396,0.118,0.118
7,0.5599,0.578,0.5414,0.5394,0.5404,0.1181,0.1181
8,0.5836,0.6085,0.5451,0.567,0.5559,0.1643,0.1644
9,0.593,0.6109,0.5424,0.5787,0.56,0.182,0.1824


Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC
0,Logistic Regression,0.5762,0.6007,0.5467,0.5572,0.5519,0.15,0.15




Unnamed: 0_level_0,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC
Fold,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1
0,0.5555,0.5348,0.5812,0.5312,0.5551,0.1127,0.1132
1,0.5264,0.507,0.5996,0.5031,0.5471,0.0587,0.0597
2,0.5229,0.5094,0.5443,0.5,0.5212,0.0475,0.0476
3,0.5396,0.525,0.6144,0.5147,0.5601,0.0851,0.0865
4,0.5484,0.5412,0.5849,0.524,0.5527,0.0995,0.1001
5,0.5739,0.5591,0.607,0.5483,0.5762,0.1501,0.1509
6,0.5475,0.526,0.6107,0.5221,0.5629,0.0999,0.1012
7,0.5458,0.533,0.6022,0.5215,0.559,0.0957,0.0967
8,0.5361,0.5255,0.5985,0.5126,0.5523,0.0769,0.0779
9,0.5366,0.5192,0.5664,0.5134,0.5386,0.0754,0.0757


Unnamed: 0,Model,Accuracy,AUC,Recall,Prec.,F1,Kappa,MCC
0,Decision Tree Classifier,0.8109,0.8484,0.861,0.7701,0.813,0.623,0.627




Modelo escolhido: tree_model
🏃 View run thoughtful-wren-618 at: http://127.0.0.1:5000/#/experiments/241919081044893865/runs/b3d8d5a688a1455ab0433181828ee47b
🧪 View experiment at: http://127.0.0.1:5000/#/experiments/241919081044893865
