In [1]:
import tensorflow as tf
import tensorflow_addons as tfa
import numpy as np

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

import os

os.environ['CUDA_VISIBLE_DEVICES'] = ''
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = '' #'true'

In [2]:
def make_batch(iterable, n=1):
    l = len(iterable)
    for ndx in range(0, l, n):
        yield iterable[ndx:min(ndx + n, l)]    

In [3]:
class DHDT(tf.Module):
    
    def __init__(
            self,
            depth=3,
            function_representation_type = 3,
            number_of_variables = 5,
            learning_rate=1e-2,
            loss='binary_crossentropy',
            random_seed=42,
            verbosity=1):    
        
        
        self.depth = depth
        self.learning_rate = learning_rate
        self.loss = tf.keras.losses.get(loss)
        self.seed = random_seed
        self.verbosity = verbosity
        self.function_representation_type = function_representation_type
        self.number_of_variables = number_of_variables
        
        tf.random.set_seed(self.seed)
        
        function_representation_length = ( 
          (2 ** self.depth - 1) * 2 + (2 ** self.depth)  if self.function_representation_type == 1 
          else (2 ** self.depth - 1) + ((2 ** self.depth - 1) * self.number_of_variables) + (2 ** self.depth) if self.function_representation_type == 2 
          else ((2 ** self.depth - 1) * self.number_of_variables * 2) + (2 ** self.depth)  if self.function_representation_type >= 3 
          else None
                                      )        
        
        self.dt_params =  tf.Variable(tf.keras.initializers.GlorotUniform(seed=self.seed)(shape=(function_representation_length,)),
                                      trainable=True,
                                      name='dt_params')
        
        print(self.dt_params.shape)
        self.internal_nodes, self.leaf_nodes = self.get_shaped_parameters_for_decision_tree()
            
        
    def fit(self, X, y, batch_size=32, epochs=100, early_stopping_epochs=5):
        
        
        for current_epoch in range(epochs):
            tf.random.set_seed(self.seed + current_epoch)
            X = tf.random.shuffle(X, seed=self.seed + current_epoch)
            tf.random.set_seed(self.seed + current_epoch)
            y = tf.random.shuffle(y, seed=self.seed + current_epoch)
            
            for index, (X_batch, y_batch) in enumerate(zip(make_batch(X, batch_size), make_batch(y, batch_size))):
                current_loss = self.backward(X_batch, y_batch)
                loss_list.append(float(current_loss))
                
                if self.verbosity > 2:
                    batch_idx = (index+1)*batch_size
                    msg = "Epoch: {:02d} | Batch: {:03d} | Loss: {:.5f} |"
                    print(msg.format(epoch, batch_idx, current_loss))                   
                          
            msg = "Epoch: {:02d} | Loss: {:.5f} |"
            print(msg.format(epoch, np.mean(loss_list)))       
                          
    def forward(self, X):
        X = tf.dtypes.cast(tf.convert_to_tensor(X), tf.float32)       

        maximum_depth = self.depth
        leaf_node_num_ = 2 ** maximum_depth
        internal_node_num_ = 2 ** maximum_depth - 1

        function_values_dhdt = tf.vectorized_map(self.calculate_function_value_from_vanilla_decision_tree_parameter_single_sample_wrapper(leaf_node_num_, 
                                                                                                                                       internal_node_num_, 
                                                                                                                                       maximum_depth,
                                                                                                                                       self.number_of_variables), 
                                                                                                                                       X)

        return function_values_dhdt  
           
    def predict(self, X):
        return forward(X)
        
    def backward(self, x,y):
        optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)#tf.compat.v1.train.GradientDescentOptimizer(learning_rate=0.01)
        with tf.GradientTape() as tape:
            predicted = self.forward(x)
            current_loss = self.loss(y, predicted)
            
        print('current_loss', current_loss)
        print('self.dt_params', self.dt_params)
        grads = tape.gradient(current_loss, self.dt_params)
        #optimizer.apply_gradients(zip(grads, self.dt_params),
        #                          global_step=tf.compat.v1.train.get_or_create_global_step())     
        
        print('grads', grads)
        print('self.dt_params', self.dt_params)
        optimizer.apply_gradients(zip(grads, self.dt_params))
        self.internal_nodes, self.leaf_nodes = self.get_shaped_parameters_for_decision_tree()
                  
        return current_loss
        
    
    def get_shaped_parameters_for_decision_tree(self):

        internal_node_num_ = 2 ** self.depth - 1 
        leaf_node_num_ = 2 ** self.depth

        if self.function_representation_type == 1:

            splits_coeff = self.dt_params[:internal_node_num_]
            splits_coeff = tf.clip_by_value(splits_coeff, clip_value_min=0, clip_value_max=1)
            splits_coeff_list = tf.split(splits_coeff, internal_node_num_)
            splits_index = tf.cast(tf.clip_by_value(tf.round(self.dt_params[internal_node_num_:internal_node_num_*2]), clip_value_min=0, clip_value_max=self.number_of_variables-1), tf.int64)
            splits_index_list = tf.split(splits_index, internal_node_num_)

            splits_list = []
            for values_node, indices_node in zip(splits_coeff_list, splits_index_list):
                sparse_tensor = tf.sparse.SparseTensor(indices=tf.expand_dims(indices_node, axis=1), values=values_node, dense_shape=[self.number_of_variables])
                dense_tensor = tf.sparse.to_dense(sparse_tensor)
                splits_list.append(dense_tensor)             

            splits = tf.stack(splits_list)            

            leaf_classes = self.dt_params[internal_node_num_*2:]  
            leaf_classes = tf.clip_by_value(leaf_classes, clip_value_min=0, clip_value_max=1)

        elif self.function_representation_type == 2:

            split_values_num_params = internal_node_num_ 
            split_index_num_params = self.number_of_variables * internal_node_num_
            leaf_classes_num_params = leaf_node_num_ 

            split_values = self.dt_params[:split_values_num_params]
            split_values_list_by_internal_node = tf.split(split_values, internal_node_num_)

            split_index_array = self.dt_params[split_values_num_params:split_values_num_params+split_index_num_params]    
            split_index_list_by_internal_node = tf.split(split_index_array, internal_node_num_)
            split_index_list_by_internal_node_by_decision_sparsity = []
            for tensor in split_index_list_by_internal_node:
                split_index_list_by_internal_node_by_decision_sparsity.append(split_tensor)
            split_index_list_by_internal_node_by_decision_sparsity_argmax = tf.split(tf.argmax(split_index_list_by_internal_node_by_decision_sparsity, axis=2), internal_node_num_)
            split_index_list_by_internal_node_by_decision_sparsity_argmax_new = []
            for tensor in split_index_list_by_internal_node_by_decision_sparsity_argmax:
                tensor_squeeze = tf.squeeze(tensor, axis=0)
                split_index_list_by_internal_node_by_decision_sparsity_argmax_new.append(tensor_squeeze)
            split_index_list_by_internal_node_by_decision_sparsity_argmax = split_index_list_by_internal_node_by_decision_sparsity_argmax_new    
            dense_tensor_list = []
            for indices_node, values_node in zip(split_index_list_by_internal_node_by_decision_sparsity_argmax,  split_values_list_by_internal_node):
                sparse_tensor = tf.sparse.SparseTensor(indices=tf.expand_dims(indices_node, axis=1), values=values_node, dense_shape=[self.number_of_variables])
                dense_tensor = tf.sparse.to_dense(sparse_tensor)
                dense_tensor_list.append(dense_tensor) 
            splits = tf.stack(dense_tensor_list)

            leaf_classes_array = self.dt_params[split_values_num_params+split_index_num_params:]  
            split_index_list_by_leaf_node = tf.split(leaf_classes_array, leaf_node_num_)

            leaf_classes = tf.squeeze(tf.stack(split_index_list_by_leaf_node))

        elif self.function_representation_type >= 3:

            split_values_num_params = self.number_of_variables * internal_node_num_
            split_index_num_params = self.number_of_variables * internal_node_num_
            leaf_classes_num_params = leaf_node_num_ 

            split_values = self.dt_params[:split_values_num_params]
            split_values_list_by_internal_node = tf.split(split_values, internal_node_num_)

            split_index_array = self.dt_params[split_values_num_params:split_values_num_params+split_index_num_params]    
            split_index_list_by_internal_node = tf.split(split_index_array, internal_node_num_)         

            split_index_list_by_internal_node_max = tfa.seq2seq.hardmax(split_index_list_by_internal_node)

            splits = tf.stack(tf.multiply(split_values_list_by_internal_node, split_index_list_by_internal_node_max))

            leaf_classes_array = self.dt_params[split_values_num_params+split_index_num_params:]  
            split_index_list_by_leaf_node = tf.split(leaf_classes_array, leaf_node_num_)

            leaf_classes = tf.squeeze(tf.stack(split_index_list_by_leaf_node))



        return splits, leaf_classes


    def calculate_function_value_from_vanilla_decision_tree_parameter_single_sample_wrapper(self, leaf_node_num_, internal_node_num_, maximum_depth, number_of_variables):

        #self.internal_nodes = tf.cast(self.internal_nodes, tf.float32)
        #self.leaf_nodes = tf.cast(self.leaf_nodes, tf.float32)   

        def calculate_function_value_from_vanilla_decision_tree_parameter_single_sample(x):

            x = tf.cast(x, tf.float32)

            internal_nodes_split = tf.split(self.internal_nodes, internal_node_num_)
            internal_nodes_split_new = [[] for _ in range(maximum_depth)]
            for i, tensor in enumerate(internal_nodes_split):
                current_depth = np.ceil(np.log2((i+1)+1)).astype(np.int32)

                internal_nodes_split_new[current_depth-1].append(tf.squeeze(tensor, axis=0))

            internal_nodes_split = internal_nodes_split_new

            split_value_list = []
            for i in range(maximum_depth):
                current_depth = i+1
                num_nodes_current_layer = 2**current_depth - 1 - (2**(current_depth-1) - 1)
                split_value_list_per_depth = []
                for j in range(num_nodes_current_layer):
                    zero_identifier = tf.not_equal(internal_nodes_split[i][j], tf.zeros_like(internal_nodes_split[i][j]))
                    split_complete = tf.greater(x, internal_nodes_split[i][j])
                    split_value = tf.reduce_any(tf.logical_and(zero_identifier, split_complete))
                    split_value_filled = tf.fill( [2**(maximum_depth-current_depth)] , split_value)
                    split_value_neg_filled = tf.fill([2**(maximum_depth-current_depth)], tf.logical_not(split_value))
                    split_value_list_per_depth.append(tf.keras.backend.flatten(tf.stack([split_value_neg_filled, split_value_filled])))        
                split_value_list.append(tf.keras.backend.flatten(tf.stack(split_value_list_per_depth)))

            split_values = tf.cast(tf.reduce_all(tf.stack(split_value_list), axis=0), tf.float32)    
            leaf_classes = tf.cast(self.leaf_nodes, tf.float32)
            final_class_probability = 1-tf.reduce_max(tf.multiply(leaf_classes, split_values))                                                                                                                                            
            return final_class_probability

        return calculate_function_value_from_vanilla_decision_tree_parameter_single_sample


        

In [4]:
X, y = make_classification(
    n_samples=10_000, n_features=5, n_informative=2, n_redundant=2, random_state=42
)

train_samples = 100  # Samples used for training the models
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    shuffle=False,
    test_size=10_000 - train_samples,
)

In [5]:
model = DHDT()

model.fit(X_train, y_train, batch_size=32, epochs=100, early_stopping_epochs=5)

2022-06-03 12:53:52.305833: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2022-06-03 12:53:52.305886: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: dws-15
2022-06-03 12:53:52.305894: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: dws-15
2022-06-03 12:53:52.306093: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 510.60.2
2022-06-03 12:53:52.306120: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 510.60.2
2022-06-03 12:53:52.306126: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 510.60.2
2022-06-03 12:53:52.306689: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in perf

(78,)
current_loss tf.Tensor(4.7070003, shape=(), dtype=float32)
self.dt_params <tf.Variable 'dt_params:0' shape=(78,) dtype=float32, numpy=
array([-0.03282875, -0.09076975, -0.00796892, -0.0531195 ,  0.17835249,
        0.17331357,  0.04504158, -0.05553168,  0.03671388, -0.11158578,
       -0.16573708,  0.03106995, -0.08166121, -0.09131939, -0.05094133,
        0.08476271, -0.01616873, -0.14997172, -0.11345824,  0.01733561,
        0.19211806, -0.04551519, -0.17774567,  0.14585008, -0.09459972,
        0.14635558,  0.05765201, -0.03145219, -0.10139881, -0.15865773,
        0.18908806, -0.1345275 , -0.07853737, -0.05179307,  0.17468913,
       -0.15274787,  0.00897281,  0.12965007, -0.1953034 ,  0.18019284,
        0.13975246, -0.04140022, -0.10971177,  0.06693865, -0.18875885,
       -0.00762086,  0.03739753,  0.07245822, -0.12163537, -0.14712685,
       -0.06302536,  0.01606283,  0.03367572,  0.18798922, -0.19577129,
        0.13074069,  0.02864893, -0.13725558, -0.15141842,  0.18969

TypeError: 'NoneType' object is not iterable