In [1]:
import numpy as np
import pandas as pd
from collections import Counter, defaultdict

In [2]:
data = np.loadtxt('wifi_localization.txt')

In [43]:
class Node:
    def __init__(self, gini):
        self.gini = gini
        self.feature_index = 0
        self.threshold = 0
        self.left = None
        self.right = None

In [3]:
pd.DataFrame(data).describe()

Unnamed: 0,0,1,2,3,4,5,6,7
count,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0,2000.0
mean,-52.3305,-55.6235,-54.964,-53.5665,-62.6405,-80.985,-81.7265,2.5
std,11.321677,3.417688,5.316186,11.471982,9.105093,6.516672,6.519812,1.118314
min,-74.0,-74.0,-73.0,-77.0,-89.0,-97.0,-98.0,1.0
25%,-61.0,-58.0,-58.0,-63.0,-69.0,-86.0,-87.0,1.75
50%,-55.0,-56.0,-55.0,-56.0,-64.0,-82.0,-83.0,2.5
75%,-46.0,-53.0,-51.0,-46.0,-56.0,-77.0,-78.0,3.25
max,-10.0,-45.0,-40.0,-11.0,-36.0,-61.0,-63.0,4.0


In [4]:
X, y = data[:,:-1], data[:,-1]

In [86]:
def get_gini_importance(y):
    frequencies = np.array([*Counter(y).values()])/len(y)
    gini = 1 - np.sum(frequencies**2)
    return gini

def get_ginis_after_split(X_slice, y, threshold):
    left_mask = X_slice < threshold
    right_mask = X_slice >= threshold
    y_left = y[left_mask]
    y_right = y[right_mask]
    left_gini = get_gini_importance(y_left)
    right_gini = get_gini_importance(y_right)
    return (len(y_left) * left_gini + len(y_right) * right_gini) / len(y)

def get_best_split_for_feature(X_slice, y_slice):
    X_slice_sorted = np.sort(X_slice)
    thresholds = np.unique((X_slice_sorted[1:] + X_slice_sorted[:-1])/2)
    current_fueature_gini = get_gini_importance(y_slice)
    best_gini = current_fueature_gini
    best_thresh = None
    for thresh in thresholds:
        gini = get_ginis_after_split(X_slice, y_slice, thresh)
#         print(gini, thresh)
        if gini < best_gini:
            best_gini = gini
            best_thresh = thresh 
    return best_gini, best_thresh

def get_split_for_dataset(X, y):
    print(f'previous gini: {get_gini_importance(y)}')
    current_fueature_gini = get_gini_importance(y)
    best_gini = current_fueature_gini
    for i in range(X.shape[1]):
        X_slice = X[:, i]
        best_gini_for_feature, best_thresh_for_feature = get_best_split_for_feature(X_slice, y)
        if best_thresh_for_feature is not None and best_gini_for_feature < best_gini:
            best_gini = best_gini_for_feature
            best_index = i
            best_thresh = best_thresh_for_feature
    return best_gini, best_index, best_thresh
            
def grow_tree(X, y, depth = 0):
    current_gini = get_gini_importance(y)
    node = Node(gini=current_gini)
    if depth < 2:
        best_gini, best_index, best_thresh = get_split_for_dataset(X, y)
        if best_index is not None and best_thresh is not None:
            left_mask = X[:, best_index] < best_thresh
            right_mask = X[:, best_index] >= best_thresh
            X_left = X[left_mask]
            X_right = X[right_mask]
            y_left = y[left_mask]
            y_right = y[right_mask]
            print(depth, best_thresh, best_index, X_right.shape)
            node.feature_index=best_index
            node.threshold=best_thresh
            node.right = grow_tree(X_right, y_right, depth+1)
            node.left = grow_tree(X_left, y_left, depth+1)
    return node

In [88]:
tree = grow_tree(X, y)

previous gini: 0.75
0 -54.5 0 (988, 7)
previous gini: 0.5029073579307971
1 -44.5 0 (450, 7)
previous gini: 0.5144881969723007
1 -59.5 4 (497, 7)


In [99]:
print(tree.gini, tree.threshold, tree.feature_index, '\n',
      tree.left.gini, tree.left.threshold, tree.left.feature_index, '\n\t',
      tree.left.left.gini, tree.left.left.threshold, tree.left.left.feature_index,'\n\t',
      tree.left.right.gini, tree.left.right.threshold, tree.left.right.feature_index,'\n',
      tree.right.gini, tree.right.threshold, tree.right.feature_index, '\n\t', 
      tree.right.left.gini, tree.right.left.threshold, tree.right.left.feature_index,'\n\t',
      tree.right.right.gini, tree.right.right.threshold, tree.right.right.feature_index,'\n',)


0.75 -54.5 0 
 0.5144881969723007 -59.5 4 
	 0.056887548308040436 0 0 
	 0.015967029541433808 0 0 
 0.5029073579307971 -44.5 0 
	 0.20225674050939035 0 0 
	 0.03492345679012354 0 0 



In [45]:
best_gini, best_index, left_mask, right_mask = get_split_for_dataset(X, y)
print(best_gini, best_index, len(X[left_mask]))
best_gini2, best_index2, left_mask2, right_mask2 = get_split_for_dataset(X, y[left_mask])
print(best_gini2, best_index2, len(X[left_mask][left_mask2]))
best_gini3, best_index3, left_mask3, right_mask3 = get_split_for_dataset(X[left_mask][left_mask2], y[left_mask][left_mask2])
print(best_gini3, best_index3)


previous gini: 0.75


ValueError: too many values to unpack (expected 4)

In [41]:
best_gini, best_index, X_left, X_right, y_left, y_right = get_split_for_dataset(X, y)
print(best_gini, best_index, len(X_left))
best_gini2, best_index2, X_left2, X_right2, y_left2, y_right2= get_split_for_dataset(X_left, y_left)
print(best_gini2, best_index2, len(X_left2))
best_gini3, best_index3, X_left3, X_right3, y_left3, y_right3 = get_split_for_dataset(X_left2, y_left2)
print(best_gini3, best_index3)


previous gini: 0.75
0.508767262485798 0 1012
previous gini: 0.5144881969723007
0.036791206581752396 4 515
previous gini: 0.056887548308040436
0.02278611356281249 3


In [None]:
def grow_tree(X, y, max_depth = 2)

In [None]:
X.Shap

In [6]:
%%time
Counter(y)

CPU times: user 536 µs, sys: 168 µs, total: 704 µs
Wall time: 710 µs


Counter({1.0: 500, 2.0: 500, 3.0: 500, 4.0: 500})

In [7]:
%%time
d = defaultdict(int)
for yy in y:
    d[yy]+=1
d

CPU times: user 570 µs, sys: 178 µs, total: 748 µs
Wall time: 752 µs


defaultdict(int, {1.0: 500, 2.0: 500, 3.0: 500, 4.0: 500})

In [23]:
np.array([*d.values()])

array([500, 500, 500, 500])

In [30]:
get_gini_importance(y)

0.75

In [8]:
# for i in range(-70, -20):
#     print(i, get_ginis_after_split(X[:,0], y, i))

In [43]:
x_sort = np.sort(X[:,1])
a=(x_sort[1:] + x_sort[:-1]) / 2

In [20]:
mask = X[:,0]>=-54.5
XX = X[:,0][mask]
mask2 = XX < -44.5
# get_ginis_after_split(X[:,0][mask], y[mask], -44.5)
get_gini_importance(y[mask][mask2]), get_ginis_after_split(XX, y[mask], -44.5)

(0.20225674050939035, 0.12604218820810487)

In [60]:
%%time
d = {}

for i in range(X.shape[1]):
    results.append(get_best_split(X[:, i], y))
results.sort(key=lambda x: x[0])
print(results)

[(0.508767262485798, -54.5), (0.5111624649300466, -56.5), (0.5298413956541929, -56.5), (0.5815969026141384, -76.5), (0.6052432784137277, -77.5), (0.6378616433611275, -57.5), (0.7411891029651735, -53.5)]
CPU times: user 239 ms, sys: 7.29 ms, total: 246 ms
Wall time: 249 ms


In [71]:
results2 = []
for i in range(1,X.shape[1]):
    mask = X[:,0]<-54.5
    x_slice = X[:,i]
    results2.append(get_best_split(x_slice[mask], y))
results2.sort(key=lambda x: x[0])
print(results2)

IndexError: boolean index did not match indexed array along dimension 0; dimension is 2000 but corresponding boolean dimension is 1012

In [76]:
mask = X[:,0]<-54.5
X[:,0][mask]>=54.5


#trzeba to zrobić rekurencyjnie

array([False, False, False, ..., False, False, False])

In [53]:
d

{'feature_0': (0.508767262485798, -54.5),
 'feature_1': (0.7411891029651735, -53.5),
 'feature_2': (0.6378616433611275, -57.5),
 'feature_3': (0.5298413956541929, -56.5),
 'feature_4': (0.5111624649300466, -56.5),
 'feature_5': (0.5815969026141384, -76.5),
 'feature_6': (0.6052432784137277, -77.5)}

In [56]:
get_best_split(X[:,0], y)

0.749624812406203 -73.5
0.749624812406203 -73.0
0.748496993987976 -72.5
0.748496993987976 -72.0
0.7473657802308078 -71.5
0.7473657802308078 -71.0
0.746766305589835 -70.5
0.746766305589835 -70.0
0.7457157258064517 -69.5
0.7457157258064517 -69.0
0.7449673683242946 -68.5
0.7449673683242946 -68.0
0.738785343443021 -67.5
0.738785343443021 -67.0
0.734196323092247 -66.5
0.734196323092247 -66.0
0.7266570823244551 -65.5
0.7266570823244551 -65.0
0.7152508248798632 -64.5
0.7152508248798632 -64.0
0.7051006949901937 -63.5
0.7051006949901937 -63.0
0.6869056277056277 -62.5
0.6869056277056277 -62.0
0.6575307983680092 -61.5
0.6575307983680092 -61.0
0.6349723874904653 -60.5
0.6349723874904653 -60.0
0.6144885171817691 -59.5
0.6144885171817691 -59.0
0.5816037598155122 -58.5
0.5816037598155122 -58.0
0.5526342464198252 -57.5
0.5526342464198252 -57.0
0.5292193886506682 -56.5
0.5292193886506682 -56.0
0.5130912124148981 -55.5
0.5130912124148981 -55.0
0.508767262485798 -54.5
0.508767262485798 -54.0
0.5132789765

(0.508767262485798, -54.5)