In [1]:
# Based on the wandb documentation: Scikit-learn integration
import os
import pickle
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import wandb

from sklearn.exceptions import ConvergenceWarning
from sklearn.metrics import accuracy_score, precision_score, recall_score, roc_auc_score, f1_score, confusion_matrix, roc_curve, mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
from sklearn.tree import plot_tree

# Silence scikit-learn ConvergenceWarning messages for the rest of the notebook.
# NOTE(review): only a decision tree is trained in this notebook — confirm this filter is still needed.
warnings.filterwarnings("ignore", category=ConvergenceWarning)

In [2]:
# W&B reported "Failed to detect the name of this notebook", so the notebook
# name is provided explicitly through the WANDB_NOTEBOOK_NAME environment variable.
os.environ['WANDB_NOTEBOOK_NAME'] = '04_opiod_wandb_tree.ipynb'

In [3]:
# Authenticate this session with Weights & Biases before any run is created.
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33midiazl[0m ([33mdev_ml_ops[0m). Use [1m`wandb login --relogin`[0m to force relogin


True

In [4]:
# Read the 'Model_data' sheet of the case-study workbook into a DataFrame.
# The path is relative, so the workbook is looked up in the current working directory.
data = 'CaseStudy_training_data.xlsx'
df = pd.read_excel(data, sheet_name='Model_data')

3. Data cleaning

In [5]:
# Clean the raw frame in a single chain: discard rows with missing values,
# remove duplicate rows, drop the 'ID' column, and rename 'rx ds' to a
# name without a space.
df_cleaned = (
    df
    .dropna()
    .drop_duplicates()
    .drop(columns=['ID'])
    .rename(columns={'rx ds': 'rx_ds'})
)

4. Feature Engineering

In [6]:
# Percentile-based bucketing of 'rx_ds': qcut splits the column into four
# quantile bins labelled Q1..Q4. `assign` returns a new frame, so df_cleaned
# itself is left untouched.
quartile_labels = ['Q1', 'Q2', 'Q3', 'Q4']
df_features = df_cleaned.assign(
    rx_ds_bucket=pd.qcut(
        df_cleaned['rx_ds'],
        q=4,
        labels=quartile_labels,
    )
)

In [7]:
# Row-wise sum over every column except 'OD', 'rx_ds' and 'rx_ds_bucket'.
# NOTE(review): this presumes all remaining columns are 0/1 flags — confirm against the sheet.
non_binary = ('OD', 'rx_ds', 'rx_ds_bucket')
binary_cols = [name for name in df_features.columns if name not in non_binary]
df_features['binary_sum'] = df_features.loc[:, binary_cols].sum(axis=1)

# Ratio of 'rx_ds' to the number of binary features that are set.
# NOTE(review): a row whose binary_sum is 0 would yield inf/NaN here — verify none exist.
df_features['rx_ds_to_binary_sum'] = df_features['rx_ds'].div(df_features['binary_sum'])

In [8]:
# One-hot encode 'rx_ds_bucket' into 'rx_ds_bucket_<label>' indicator columns,
# append them to the feature frame, and drop the original categorical column.
df_one_hot = pd.get_dummies(df_features['rx_ds_bucket'], prefix='rx_ds_bucket')
df_features = (
    pd.concat([df_features, df_one_hot], axis=1)
    .drop(columns=['rx_ds_bucket'])
)

## Runs for model training

### 1. Classification - Decision Tree

In [9]:
# Work on a copy so the engineered feature frame stays available unchanged.
df_tree = df_features.copy()

# 'OD' is the target. The raw 'rx_ds' column is excluded from the features;
# the bucketed and ratio columns derived from it remain.
X = df_tree.drop(['OD', 'rx_ds'], axis=1)
y = df_tree['OD']

# Hold out 20% of the rows for evaluation; the fixed seed keeps the split reproducible.
# (train_test_split is imported with the other dependencies in the first cell.)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [10]:
# Performance metrics for a fitted classifier on a held-out set
def calculate_metrics(y_test, y_pred, model, X_test):
    """Compute classification metrics from labels, predictions and the model.

    Parameters
    ----------
    y_test : true labels.
    y_pred : predicted labels for the same samples.
    model : fitted estimator exposing ``predict_proba``; the second
        probability column is used as the score for ROC AUC.
    X_test : feature matrix passed to ``model.predict_proba``.

    Returns
    -------
    dict
        Keys: accuracy, precision, recall, f1, roc_auc, ppv, npv, specificity.
    """
    metrics = {
        "accuracy": accuracy_score(y_test, y_pred),
        "precision": precision_score(y_test, y_pred),
        "recall": recall_score(y_test, y_pred),
        "f1": f1_score(y_test, y_pred),
        "roc_auc": roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]),
    }

    # The flattened confusion matrix is unpacked as TN, FP, FN, TP; this
    # requires exactly four cells, i.e. a two-class problem.
    true_neg, false_pos, false_neg, true_pos = confusion_matrix(y_test, y_pred).ravel()

    metrics.update(
        ppv=true_pos / (true_pos + false_pos),
        npv=true_neg / (true_neg + false_neg),
        specificity=true_neg / (true_neg + false_pos),
    )
    return metrics

In [None]:
# Metrics that are specific to the tree model (structure and pruning path)
def calculate_tree_metrics(model, X_test, y_test):
    """Collect structural characteristics of a tree model.

    Parameters
    ----------
    model : estimator exposing ``get_depth``, ``get_n_leaves`` and
        ``cost_complexity_pruning_path``.
    X_test, y_test : data forwarded unchanged to
        ``model.cost_complexity_pruning_path``.
        NOTE(review): per the scikit-learn docs the pruning path is derived
        from the data supplied here — confirm callers pass the intended split.

    Returns
    -------
    dict
        Keys: tree_depth, num_leaves, ccp_alphas, impurities.
    """
    summary = {
        "tree_depth": model.get_depth(),
        "num_leaves": model.get_n_leaves(),
    }
    pruning = model.cost_complexity_pruning_path(X_test, y_test)
    summary.update(ccp_alphas=pruning.ccp_alphas, impurities=pruning.impurities)
    return summary

In [None]:
# Hyperparameters for the decision tree; these are recorded as the run's config.
tree_params = {
    'criterion': 'gini',
    'splitter': 'best',
    'max_depth': None,
    'min_samples_split': 2,
    'min_samples_leaf': 1,
    'random_state': 42
}

# New wandb project and run.
# The hyperparameters are passed through `config=` so they are stored with the
# run. Assigning `wandb.config = tree_params` would only rebind the module
# attribute to a plain dict and the run would be logged without its config.
# After init, `wandb.config` refers to this run's config and still supports
# `**wandb.config` unpacking.
run = wandb.init(
    project='wandb-sklearn-tree',
    name="classifier_decision_tree",
    config=tree_params,
)

In [22]:
# Build the classifier from the configured hyperparameters and fit it on the
# training split (fit returns the estimator itself), then predict the held-out set.
tree_model = DecisionTreeClassifier(**wandb.config).fit(X_train, y_train)
y_pred_tree = tree_model.predict(X_test)

### Visualizations

In [None]:
# Visualize the decision tree to understand the decision paths and the split criteria at each node
plt.figure(figsize=(20, 10))
plot_tree(
    tree_model,
    filled=True,
    rounded=True,
    class_names=["Not OD", "OD"],
    # Pass a plain list: some scikit-learn releases validate this parameter
    # strictly and reject a pandas Index.
    feature_names=list(X.columns),
)
plt.show()

In [None]:
# Table of feature importances from the fitted tree, one row per feature column;
# the sorted view (most important first) is displayed, the stored frame stays unsorted.
feature_importances = pd.DataFrame({'Feature': X.columns}).assign(
    Importance=tree_model.feature_importances_
)
feature_importances.sort_values('Importance', ascending=False)

### Tree characteristics and metrics

In [None]:
# Depth of the fitted tree, as a measure of how complex the model is.
# A deeper tree might have a higher chance of overfitting.
tree_depth = tree_model.get_depth()
tree_depth

In [None]:
# Number of leaves in the fitted tree: another measure of model complexity.
num_leaves = tree_model.get_n_leaves()
num_leaves

In [None]:
# Extract the decision paths for all samples in X_test and display the
# object returned by decision_path.
decision_paths = tree_model.decision_path(X_test)
decision_paths

In [None]:
# Compute the cost-complexity pruning path on the training split.
# The path is only stored and displayed here; tree_model itself is not pruned in this cell.
ccp_path = tree_model.cost_complexity_pruning_path(X_train, y_train)
ccp_path

### Logging metrics

In [23]:
# Logging the performance metrics.
# Predictions on X_test are computed again here, scored with
# calculate_metrics, and sent to W&B in a single wandb.log call.
y_pred = tree_model.predict(X_test)
tree_metrics = calculate_metrics(y_test, y_pred, tree_model, X_test)

wandb.log(tree_metrics)

In [None]:
# Logging the tree characteristics.
# The pruning path is computed from the training split — the data the tree was
# fit on — so the logged alphas/impurities describe this model rather than a
# tree grown on the small held-out set.
tree_characteristics = calculate_tree_metrics(tree_model, X_train, y_train)

wandb.log(tree_characteristics)

#### Logging artifacts

In [24]:
# Persist the fitted tree to disk with pickle
model_path = "models/tree_model.pkl"
os.makedirs('models', exist_ok=True)
with open(model_path, "wb") as model_file:
    pickle.dump(tree_model, model_file)

# Attach the pickled file to a W&B artifact of type "model" and log it
artifact = wandb.Artifact("tree_model", type="model")
artifact.add_file(model_path)
wandb.log_artifact(artifact)

<Artifact log_mode>

In [17]:
# Save the training and validation splits, each with its features and the target column
os.makedirs('data', exist_ok=True)
datasets = {
    "training": pd.concat([X_train, y_train], axis=1),
    "validation": pd.concat([X_test, y_test], axis=1),
}

# The loop variable is deliberately not called `df`: that would overwrite the
# raw DataFrame loaded from the workbook and break re-running earlier cells.
for name, split_df in datasets.items():
    split_df.to_csv(f'data/{name}.csv', index=False)

# Log the `data` directory as an artifact.
# NOTE: add_dir uploads everything in ./data, so CSV files left over from
# earlier runs (e.g. a misspelled 'trainig.csv') should be deleted first.
artifact = wandb.Artifact('train_val_sets', type='dataset', metadata={"Source": "CaseStudy_training_data.xlsx"})
artifact.add_dir('data')
wandb.log_artifact(artifact)

[34m[1mwandb[0m: Adding directory to artifact (./data)... Done. 0.0s


<Artifact train_val_sets>

In [25]:
# Mark the W&B run as finished once all metrics and artifacts have been logged.
wandb.finish()

VBox(children=(Label(value='0.002 MB of 0.002 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
accuracy,▁
f1,▁
npv,▁
ppv,▁
precision,▁
recall,▁
roc_auc,▁
specificity,▁

0,1
accuracy,0.715
f1,0.6069
npv,0.86842
ppv,0.51163
precision,0.51163
recall,0.74576
roc_auc,0.78002
specificity,0.70213
