# Soft Decision Tree distilled from Neural Network

Based on "Distilling a Neural Network Into a Soft Decision Tree" by Nicholas Frosst, Geoffrey Hinton.

In [1]:
import catboost
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import scipy.signal as signal
import scipy.stats as stats
import uncertainty_toolbox as uct
import torch
import torch.nn as nn

from typing import Any, Dict, List, Union
from tqdm import tqdm

from sklearn.metrics import auc, precision_recall_curve, roc_auc_score
from sklearn.mixture import BayesianGaussianMixture, GaussianMixture
from sklearn.model_selection import train_test_split

from torch.utils.data import DataLoader, TensorDataset

from src.probabilistic_flow_boosting.extras.datasets.uci_dataset import UCIDataSet
from src.probabilistic_flow_boosting.tfboost.tree import EmbeddableCatBoostPriorNormal
from src.probabilistic_flow_boosting.tfboost.tfboost import TreeFlowBoost
from src.probabilistic_flow_boosting.tfboost.flow import ContinuousNormalizingFlow
from src.probabilistic_flow_boosting.pipelines.reporting.nodes import calculate_nll
from src.probabilistic_flow_boosting.pipelines.modeling.utils import setup_random_seed

from src.probabilistic_flow_boosting.tfboost.soft_decision_tree import SoftDecisionTree


pd.set_option('display.float_format', lambda x: '%.5f' % x)

In [2]:
RANDOM_SEED = 42
TRAIN = False
MODEL_FILEPATH = 'treeflow_wine.model'

setup_random_seed(RANDOM_SEED)

In [3]:
x_train = UCIDataSet(
    filepath_data = "data/01_raw/UCI/wine-quality-red/data.txt",
    filepath_index_columns = "data/01_raw/UCI/wine-quality-red/index_features.txt",
    filepath_index_rows = "data/01_raw/UCI/wine-quality-red/index_train_1.txt"
).load()
y_train = UCIDataSet(
    filepath_data = "data/01_raw/UCI/wine-quality-red/data.txt",
    filepath_index_columns = "data/01_raw/UCI/wine-quality-red/index_target.txt",
    filepath_index_rows = "data/01_raw/UCI/wine-quality-red/index_train_1.txt"
).load()

x_test = UCIDataSet(
    filepath_data = "data/01_raw/UCI/wine-quality-red/data.txt",
    filepath_index_columns = "data/01_raw/UCI/wine-quality-red/index_features.txt",
    filepath_index_rows = "data/01_raw/UCI/wine-quality-red/index_test_1.txt"
).load()
y_test = UCIDataSet(
    filepath_data = "data/01_raw/UCI/wine-quality-red/data.txt",
    filepath_index_columns = "data/01_raw/UCI/wine-quality-red/index_target.txt",
    filepath_index_rows = "data/01_raw/UCI/wine-quality-red/index_test_1.txt"
).load()

In [4]:
x_tr, x_val, y_tr, y_val = train_test_split(x_train, y_train, test_size = 0.2, random_state=RANDOM_SEED)

In [5]:
x_tr = torch.Tensor(x_tr.values)
x_val = torch.Tensor(x_val.values)
x_test = torch.Tensor(x_test.values)

y_tr = torch.Tensor(y_tr.values)
y_val = torch.Tensor(y_val.values)
y_test = torch.Tensor(y_test.values)
# , x_val, y_tr, y_val 

In [6]:
# Parameters
input_dim = x_tr.shape[1]  # the number of input dimensions
output_dim = y_tr.shape[1] # the number of outputs (i.e., # classes on MNIST)
depth = 3              # tree depth
lamda = 1e-3           # coefficient of the regularization term
lr = 1e-3              # learning rate
weight_decaly = 5e-4   # weight decay
batch_size = 128       # batch size
epochs = 150            # the number of training epochs
log_interval = 100     # the number of batches to wait before printing logs
use_cuda = False       # whether to use GPU

# Model and Optimizer
tree = SoftDecisionTree(input_dim, output_dim, depth, lamda, use_cuda)

optimizer = torch.optim.Adam(
    tree.parameters(),
    lr=lr,
    weight_decay=weight_decaly
)

# Load data
train_loader: DataLoader = DataLoader(
    dataset=TensorDataset(x_tr, y_tr),
    shuffle=True,
    batch_size=batch_size
)

val_loader: DataLoader = DataLoader(
    dataset=TensorDataset(x_val, y_val),
    shuffle=False,
    batch_size=batch_size
)

test_loader: DataLoader = DataLoader(
    dataset=TensorDataset(x_test, y_test),
    shuffle=False,
    batch_size=batch_size
)

# Utils
best_testing_acc = np.infty
testing_acc_list = []
training_loss_list = []
criterion = nn.MSELoss(reduce = 'sum')
device = torch.device("cuda" if use_cuda else "cpu")

for epoch in range(epochs):
    
    # Training
    tree.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        output, penalty = tree.forward(data, is_training_data=True)

        loss = criterion(output, target)
        loss += penalty

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Evaluating
    tree.eval()
    l = 0.

    for batch_idx, (data, target) in enumerate(val_loader):
        data, target = data.to(device), target.to(device)

        output = tree.forward(data)
        l += criterion(output, target)

    accuracy = np.sqrt(l.detach().numpy().item() / len(val_loader.dataset))

    if accuracy < best_testing_acc:
        best_testing_acc = accuracy

    msg = f"Epoch: {epoch} | Test Loss: {accuracy}"
    print(msg)
    print('-'*50)




Epoch: 0 | Test Loss: 0.562272155204594
--------------------------------------------------
Epoch: 1 | Test Loss: 0.5611798371296229
--------------------------------------------------
Epoch: 2 | Test Loss: 0.5601760044986146
--------------------------------------------------
Epoch: 3 | Test Loss: 0.5592155641342641
--------------------------------------------------
Epoch: 4 | Test Loss: 0.5582763908538477
--------------------------------------------------
Epoch: 5 | Test Loss: 0.5573468284996032
--------------------------------------------------
Epoch: 6 | Test Loss: 0.5564243781497276
--------------------------------------------------
Epoch: 7 | Test Loss: 0.555506905438887
--------------------------------------------------
Epoch: 8 | Test Loss: 0.5545917601486131
--------------------------------------------------
Epoch: 9 | Test Loss: 0.5536797193442786
--------------------------------------------------
Epoch: 10 | Test Loss: 0.5527701274570335
----------------------------------------

In [7]:
l = 0.

for batch_idx, (data, target) in enumerate(test_loader):
    data, target = data.to(device), target.to(device)

    output = tree.forward(data)
    l += criterion(output, target)

accuracy = np.sqrt(l.detach().numpy().item() / len(test_loader.dataset))
print(accuracy)

0.4626022019745783


## Exploration of the model

In [40]:
mu, _ = tree._forward(x_test) # Probability of landing in a particular node

In [45]:
mu[0]

tensor([9.9975e-01, 2.8127e-05, 8.6655e-05, 1.4065e-05, 1.2591e-05, 2.4486e-05,
        6.1152e-05, 2.3365e-05], grad_fn=<SelectBackward0>)

In [26]:
batch_size = x_test.shape[0]

X = tree._data_augment(x_test)

path_prob = tree.inner_nodes(X)
path_prob = torch.unsqueeze(path_prob, dim=2)
path_prob = torch.cat((path_prob, 1 - path_prob), dim=2)

In [27]:
_mu = X.data.new(batch_size, 1, 1).fill_(1.0)
_penalty = torch.tensor(0.0).to(tree.device)

# Iterate through internal odes in each layer to compute the final path
# probabilities and the regularization term.
begin_idx = 0
end_idx = 1

for layer_idx in range(0, tree.depth):
    _path_prob = path_prob[:, begin_idx:end_idx, :]
    print(_path_prob.shape)
    print(_path_prob[0])

    _mu = _mu.view(batch_size, -1, 1).repeat(1, 1, 2)
    _mu = _mu * _path_prob  # update path probabilities

    begin_idx = end_idx
    end_idx = begin_idx + 2 ** (layer_idx + 1)

torch.Size([160, 1, 2])
tensor([[9.9988e-01, 1.2159e-04]], grad_fn=<SelectBackward0>)
torch.Size([160, 2, 2])
tensor([[9.9990e-01, 1.0073e-04],
        [3.0492e-01, 6.9508e-01]], grad_fn=<SelectBackward0>)
torch.Size([160, 4, 2])
tensor([[9.9997e-01, 2.8133e-05],
        [8.6036e-01, 1.3964e-01],
        [3.3958e-01, 6.6042e-01],
        [7.2354e-01, 2.7646e-01]], grad_fn=<SelectBackward0>)


In [34]:
tree.leaf_node_num_

8

In [33]:
# Probability of 
_mu[0]

tensor([[9.9975e-01, 2.8127e-05],
        [8.6655e-05, 1.4065e-05],
        [1.2591e-05, 2.4486e-05],
        [6.1152e-05, 2.3365e-05]], grad_fn=<SelectBackward0>)

In [11]:
mu = _mu.view(batch_size, tree.leaf_node_num_)

In [12]:
pred = tree.leaf_nodes(mu)

In [13]:
pred

tensor([[1.4929],
        [1.4926],
        [1.4917],
        [1.4931],
        [1.4932],
        [1.4932],
        [1.4932],
        [1.4932],
        [1.4930],
        [1.4932],
        [1.4932],
        [1.4932],
        [1.4907],
        [1.4931],
        [1.4932],
        [1.4929],
        [1.4930],
        [1.4932],
        [1.4924],
        [1.4932],
        [1.4932],
        [1.4929],
        [1.4932],
        [1.4921],
        [1.4932],
        [1.4932],
        [1.4932],
        [1.4931],
        [1.4922],
        [1.4931],
        [1.4932],
        [1.4895],
        [1.4932],
        [1.4874],
        [1.4932],
        [1.4928],
        [1.4930],
        [1.4929],
        [1.4922],
        [1.4922],
        [1.4932],
        [1.4932],
        [1.4932],
        [1.4932],
        [1.4931],
        [1.4909],
        [1.4932],
        [1.4932],
        [1.4932],
        [1.4932],
        [1.4931],
        [1.4932],
        [1.4932],
        [1.4932],
        [1.4932],
        [1