# In [2]:
import create_graphs
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import roc_curve, roc_auc_score

def run_decision_tree_classifier(X_train, y_train, X_val, y_val, *, name="Decision Tree"):
    """
    Trains a Decision Tree Classifier, evaluates it, and logs results to WandB.

    Args:
        X_train: Training features, passed to ``DecisionTreeClassifier.fit``.
        y_train: Training labels, passed to ``DecisionTreeClassifier.fit``.
        X_val: Validation features, scored with ``predict_proba``.
        y_val: Validation labels, used for the AUC and the ROC plot.
        name: Keyword-only run name. Used as the WandB run name and forwarded
            to ``create_wandb_graphs``. Defaults to "Decision Tree", the run
            name that was previously hard-coded.

    Returns:
        None. Results are only logged to WandB.
    """
    # NOTE(review): `wandb` and `create_wandb_graphs` are not imported in the
    # visible part of this file (only `import create_graphs` is) -- confirm
    # both names are in scope wherever this function is called.

    # Using the run as a context manager finishes it when the block exits,
    # including when training or evaluation raises, so a failed call does not
    # leave a run open for the next model.
    with wandb.init(project="xG-best-model", job_type="decision_tree", name=name):
        # Train Decision Tree Classifier
        decision_tree_model = DecisionTreeClassifier(max_depth=10, random_state=42)
        decision_tree_model.fit(X_train, y_train)

        # Predict probabilities
        pred_probs = decision_tree_model.predict_proba(X_val)
        # Column 1 is treated as the positive ("Goal") class; this assumes a
        # binary target -- TODO confirm against the caller's labels.
        y_pred = pred_probs[:, 1]

        # Evaluate model
        auc = roc_auc_score(y_val, y_pred)

        # Log AUC and ROC Curve (the plot takes the full probability matrix)
        wandb.log({"AUC": auc})
        wandb.log({"ROC Curve": wandb.plot.roc_curve(y_val, pred_probs, labels=["Non-Goal", "Goal"])})

        # Kept inside the run context in case the helper logs to WandB as well.
        create_wandb_graphs(y_val, y_pred, name)