<a href="https://colab.research.google.com/github/wandb/examples/blob/master/colabs/scikit/wandb_decision_tree.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


## Author: [@SauravMaheshkar](https://twitter.com/MaheshkarSaurav)

# Packages 📦 and Basic Setup
---

## Install Packages

In [None]:
%%capture
## Install Sklearn
!pip install -U scikit-learn
## Install the latest version of wandb client 🔥🔥
!pip install -q --upgrade wandb

## Project Configuration using **`wandb.config`**

In [None]:
import wandb

## Importing Libraries
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

In [None]:
# Log in to Weights & Biases before a run is started with wandb.init().
wandb.login()

In [None]:
# Initialize the run and register every hyperparameter with it in one step.
# Feel free to change these and experiment !!
run = wandb.init(
    project='simple-scikit',
    config={
        "max_depth": 5,
        "min_samples_split": 2,
        "clf_criterion": "gini",
        # "mse" was renamed to "squared_error" in scikit-learn 1.0 and the old
        # spelling was removed in 1.2, so it fails on an up-to-date install.
        "reg_criterion": "squared_error",
        "splitter": "best",
        "dataset": "iris",
        "test_size": 0.2,
        "random_state": 42,
        "labels": ['setosa', 'versicolor', 'virginica'],
    },
)

# Short alias used by the cells below to read the values back
# (e.g. config.max_depth, config.test_size).
config = wandb.config

# 💿 Dataset
---

In [None]:
## Load Iris as a (features, target) pair; as_frame=True requests pandas objects
iris = load_iris(as_frame=True, return_X_y=True)

## Unpack the pair into the feature table and the label column
dataset, target = iris

# ✍️ Model Architecture
---

## Classification

In [None]:
# Load the raw feature matrix and labels, then hold out a test split using the
# fraction and seed stored in the run config.
X, y = load_iris(return_X_y=True)
x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=config.test_size, random_state=config.random_state
)

# Seed the tree too: scikit-learn permutes the features at every split even
# with splitter="best", so without random_state the fitted tree (and the
# logged confusion matrix) can differ from one run to the next.
clf = DecisionTreeClassifier(
    max_depth=config.max_depth,
    min_samples_split=config.min_samples_split,
    criterion=config.clf_criterion,
    splitter=config.splitter,
    random_state=config.random_state,
)
clf = clf.fit(x_train, y_train)

y_pred = clf.predict(x_test)

# Visualize Confusion Matrix
wandb.sklearn.plot_confusion_matrix(y_test, y_pred, config.labels)

## Regression

In [None]:
# Reload the raw arrays so this cell does not depend on the classification
# cell's variables. Note the regression target here is the iris label array.
X, y = load_iris(return_X_y=True)

# Hold out a test split using the fraction and seed stored in the run config.
x_train, x_test, y_train, y_test = train_test_split(
    X, y, test_size=config.test_size, random_state=config.random_state
)

# Seed the tree too: scikit-learn permutes the features at every split even
# with splitter="best", so without random_state the fitted tree (and the
# logged plots) can differ from one run to the next.
reg = DecisionTreeRegressor(
    max_depth=config.max_depth,
    min_samples_split=config.min_samples_split,
    criterion=config.reg_criterion,
    splitter=config.splitter,
    random_state=config.random_state,
)

reg = reg.fit(x_train, y_train)

# All regression plots
wandb.sklearn.plot_regressor(reg, x_train, x_test, y_train, y_test, model_name='DecisionTreeRegressor')

In [None]:
# End the W&B run that was started with wandb.init() earlier in the notebook.
wandb.finish()