In [2]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt

from sklearn import datasets
from scipy import optimize

In [491]:
def find_entropy(Y):
    counts = {}
    for y in Y:
        if y in counts:
            counts[y] += 1
        else:
            counts[y] = 1
    s = float(len(Y))
    return 1 - sum([(i / s) ** 2 for i in counts.values()])

In [524]:
def find_predicate_by_bruteforce(X, Y):
    min_entropy = find_entropy(Y)
    if min_entropy == 0.:
        return None
    predicate = None
    n_Y = len(Y)
    for obj in X:
        for f_id, f_val in enumerate(obj):
            mask = X[:, f_id] <= f_val
            y1, y2 = Y[mask], Y[~mask]
            n_y1, n_y2 = len(y1), len(y2)
            if n_y2 == 0 or n_y2 == 0:
                continue
            e1 = find_entropy(y1) * n_y1
            e2 = find_entropy(y2) * n_y2
            local_entropy = (e1 + e2) / n_Y
            if local_entropy < min_entropy:
                min_entropy = local_entropy
                predicate = f_id, f_val
    return predicate

In [515]:
def find_predicate_by_minimization(X, Y):

    def best_split(f_val, X, Y):
        mask = X <= f_val
        y1, y2 = Y[mask], Y[~mask]
        n_y1, n_y2 = len(y1), len(y2)
        e1 = find_entropy(y1) * n_y1
        e2 = find_entropy(y2) * n_y2
        return (e1 + e2) / len(Y)

    min_entropy = find_entropy(Y)
    if min_entropy == 0.:
        return None
    predicate = None
    for f_id in range(X.shape[1]):
        X_f = X[:, f_id]
        for x0 in np.linspace(X_f.min(), X_f.max(), 10):
            res = optimize.fmin(best_split, x0, (X_f, Y), disp=False)
            xopt = res[0]
            fopt = best_split(xopt, X_f, Y)
            if fopt < min_entropy:
                min_entropy = fopt
                predicate = f_id, xopt
    return predicate

In [551]:
def find_predicate_by_optimal_compares(X, Y):
    min_impurity = find_entropy(Y)
    if min_impurity == 0.:
        return None
    Ynum = float(Y.shape[0])
    fnum = X.shape[1]
    predicate = None
    uniques, counts = np.unique(Y, return_counts=True)
    D = np.concatenate((X, Y.reshape(Y.shape[0], 1)), axis=1)
    D = D.view(','.join(['f8'] * (fnum + 1)))
    lname = 'f%s' % fnum
    for fid in range(fnum):
        left = dict.fromkeys(uniques, 0)
        right = dict(zip(uniques, counts))
        fname = 'f%s' % fid
        last_label = None
        for fval, label in np.sort(D[[fname, lname]], order=fname, axis=0).flatten():
            left[label] += 1
            right[label] -= 1
            if last_label == label:
                continue
            last_label = label
            left_values, right_values = left.values(), right.values()
            left_num, right_num = float(sum(left_values)), float(sum(right_values))
            if right_num == 0:
                continue
            left_impurity = 1 - sum([(i / left_num) ** 2 for i in left_values])
            right_impurity = 1 - sum([(i / right_num) ** 2 for i in right_values])
            current_impurity = (left_impurity * left_num + right_impurity * right_num) / Ynum
            if current_impurity < min_impurity:
                min_impurity = current_impurity
                predicate = fid, fval
    return predicate

In [822]:
def find_predicate_by_optimal_compares(D):
    No = float(D.shape[0])
    Nf = self_n_features
    uniques, counts = np.unique(D['label'], return_counts=True)
    min_impurity = 1 - sum([(i / No) ** 2 for i in counts])
    if min_impurity == 0.:
        return None
    predicate = None
    for fid in range(Nf):
        left = dict.fromkeys(uniques, 0)
        right = dict(zip(uniques, counts))
        fname = 'f%s' % fid
        Dp = np.sort(D[[fname, 'label']], order=fname, axis=0)
        for label, objs in itertools.groupby(Dp.flatten(), lambda x: x[1]):
            objs = list(objs)
            left[label] += len(objs)
            right[label] -= len(objs)
            left_values, right_values = left.values(), right.values()
            left_num, right_num = float(sum(left_values)), float(sum(right_values))
            if right_num == 0:
                continue
            left_impurity = 1 - sum([(i / left_num) ** 2 for i in left_values])
            right_impurity = 1 - sum([(i / right_num) ** 2 for i in right_values])
            current_impurity = (left_impurity * left_num + right_impurity * right_num) / No
            if current_impurity < min_impurity:
                min_impurity = current_impurity
                predicate = fid, objs.pop()[0]
    return predicate

In [553]:
X, Y = datasets.make_classification(n_samples=1000, n_features=10)

In [530]:
%timeit find_predicate_by_bruteforce(X, Y)

1 loops, best of 3: 2.69 s per loop


In [521]:
%timeit find_predicate_by_minimization(X, Y)

1 loops, best of 3: 956 ms per loop


In [823]:
dt = [('f%s' % i, 'f8') for i in range(X.shape[1])]
dt.append(('label', 'f8'))
D = np.concatenate((X, Y.reshape(Y.shape[0], 1)), axis=1)
D = D.view(dt)

%timeit find_predicate_by_optimal_compares(D)

10 loops, best of 3: 31.8 ms per loop
