In [1]:
import random
import numpy as np
import pandas as pd

from sklearn import tree
from pprint import pprint
from keras.utils.np_utils import to_categorical

Using TensorFlow backend.


In [2]:
FILE_NAME = 'data.csv'
MAX_Q = 10

# Each cell of the table holds one of five answer levels:
#     +1   : True
#     +0.5 : Maybe true
#     0    : Unknown
#     -0.5 : Maybe false
#     -1   : False

# The first CSV column is used as the row index (one row per plate); the
# remaining columns are the questions.
df = pd.read_csv(FILE_NAME, index_col=0)
data = df.values
n_plates, n_questions = data.shape

# Positional index -> label lookups for columns (questions) and rows (plates).
idx2col = dict(enumerate(df.columns))
idx2row = dict(enumerate(df.index))

print('Data information:')
print('Number of plates: %d' % n_plates)
print('Number of questions: %d' % n_questions)

Data information:
Number of plates: 24
Number of questions: 19


In [3]:
# Util functions
def is_unique(idx, data):
    """Tell whether row ``idx`` of ``data`` differs from every other row.

    Args:
        idx: Position of the row to check.
        data: Indexable sequence of rows (e.g. a 2-D numpy array).

    Returns:
        ``False`` as soon as some row at a position other than ``idx``
        compares equal element-wise to ``data[idx]``, ``True`` otherwise.
    """
    reference = data[idx]
    other_rows = (row for pos, row in enumerate(data) if pos != idx)
    return not any(np.all(reference == row) for row in other_rows)

In [4]:
# Build decision tree
# Every plate gets its own class label (its row position), so the tree is
# fitted to tell the plates apart from their answer vectors.
# NOTE(review): max_depth (20) is larger than MAX_Q (10); a decision path
# longer than MAX_Q would not fit the padded input built by the training
# generator -- confirm this cannot happen for this data set.
X = data
y = np.arange(n_plates)

clf = tree.DecisionTreeClassifier(
    criterion='entropy',
    max_depth=20,
    random_state=0,
).fit(X, y)

In [6]:
# Check decision paths
# Pick one plate at random, list the questions (feature ids) tested along its
# decision path, and verify those answers alone single the plate out.
idx = random.choice(range(n_plates))
path_nodes = clf.decision_path([X[idx]]).indices
# The last node on the path is dropped: its feature entry is not a question.
t = [clf.tree_.feature[node] for node in path_nodes[:-1]]

print('Plate: %s\n\nQuestions:' % idx2row[idx])
questions = [idx2col[i] for i in t]
pprint(questions)
print('\nCheck uniqueness: %s' % is_unique(idx, data[:, t]))

Plate: Causa rellena de pollo

Questions:
['Lleva papas?', 'Lleva pollo?', 'Es caliente?']

Check uniqueness: True


In [7]:
# Get paths for each plate
# plate_paths[k] is the list of feature ids tested on the way to plate k's
# leaf (the final node of each decision path is dropped).
plate_paths = []

for idx, answers in enumerate(X):
    visited_nodes = clf.decision_path([answers]).indices
    plate_paths.append([clf.tree_.feature[node] for node in visited_nodes[:-1]])

# Use pprint(plate_paths) to inspect the result.

In [15]:
# Show every prefix of the first plate's path, shortest first, ending with
# the complete path.
# The range runs up to len(path) inclusive so the full path is printed inside
# the loop; the previous version printed it after the loop using the leaked
# loop variable, which fails (or uses a stale value) when the path has fewer
# than two questions.
first_path = plate_paths[0]
for n_asked in range(1, len(first_path) + 1):
    print(first_path[:n_asked])

[15]
[15, 11]
[15, 11, 3]
[15, 11, 3, 17]


In [27]:
# Train data generator
def _build_training_arrays(data, max_q, n_q, n_p):
    """Build one complete batch of training arrays from the plate paths.

    For a path ``[q0, q1, ..., qk]`` belonging to plate ``p`` this emits:
      * one sample per proper prefix ``[q0..qi-1]`` whose target is the next
        question ``qi`` (control class 0), and
      * one sample for the full path whose target is the plate ``p``
        (control class 1).

    Raises:
        ValueError: if a path holds more than ``max_q`` questions.
    """
    # Each entry: (asked questions, next question or -1, plate or -1, control)
    samples = []

    for plate, path in enumerate(data):
        path = list(path)
        if len(path) > max_q:
            raise ValueError(
                'Path of plate %d has %d questions but max_q is %d'
                % (plate, len(path), max_q))

        for n_asked in range(1, len(path)):
            samples.append((path[:n_asked], path[n_asked], -1, 0))

        # Use the full path directly instead of slicing with the inner loop
        # variable, which is undefined/stale for paths shorter than two.
        samples.append((path, -1, plate, 1))

    n = len(samples)
    out_x = np.zeros((n, max_q))
    out_y_q = np.zeros((n, n_q))
    out_y_p = np.zeros((n, n_p))
    out_y_c = np.zeros((n, 2))

    for row, (asked, question, plate, control) in enumerate(samples):
        # NOTE(review): question ids are written raw and the padding is 0, so
        # question 0 cannot be told apart from padding in out_x.
        out_x[row, :len(asked)] = asked
        # One-hot rows are set by index; equivalent to to_categorical(...).
        out_y_c[row, control] = 1

        if question > -1:
            out_y_q[row, question] = 1

        if plate > -1:
            out_y_p[row, plate] = 1

    return out_x, out_y_q, out_y_p, out_y_c


def generator(data, max_q, n_q, n_p):
    """Endlessly yield the full training batch derived from ``data``.

    Args:
        data: Sequence of paths, one per plate; each path is a sequence of
            question ids. The position of a path is used as the plate id.
        max_q: Width of the zero-padded input rows.
        n_q: Number of question classes (width of the question target).
        n_p: Number of plate classes (width of the plate target).

    Yields:
        ``(out_x, out_y_q, out_y_p, out_y_c)`` float arrays of shapes
        ``(n, max_q)``, ``(n, n_q)``, ``(n, n_p)`` and ``(n, 2)``, where ``n``
        is the total number of samples. ``out_y_q`` / ``out_y_p`` rows are
        one-hot or all-zero when that target does not apply; ``out_y_c`` is
        one-hot with column 0 for "next question" samples and column 1 for
        "plate" samples. Fresh arrays are built for every yield.

    Raises:
        ValueError: if a path holds more than ``max_q`` questions.
    """
    # TODO: liar behaviour
    # TODO: unoptimize path
    # TODO: use guess on paths (0.5 or -0.5)
    while True:
        yield _build_training_arrays(data, max_q, n_q, n_p)

In [28]:
# Smoke-test the generator: show the first four rows of each array in the
# first batch, then stop when the second batch arrives.
gen_test = generator(plate_paths, MAX_Q, n_questions, n_plates)

for i, batch in enumerate(gen_test):
    if i >= 1:
        break

    t_x, t_y_q, t_y_p, t_y_c = batch
    for arr in (t_x, t_y_q, t_y_p, t_y_c):
        print(arr[:4])

[[ 15.   0.   0.   0.   0.   0.   0.   0.   0.   0.]
 [ 15.  11.   0.   0.   0.   0.   0.   0.   0.   0.]
 [ 15.  11.   3.   0.   0.   0.   0.   0.   0.   0.]
 [ 15.  11.   3.  17.   0.   0.   0.   0.   0.   0.]]
[[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.  0.  0.  0.  0.  0.  0.
   0.]
 [ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.
   0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.]]
[[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.]
 [ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  0.
   0.  0.  0.  0.  0.  0.]]
[[ 1.  0.]
 [ 1.  0.]
 [ 1.  0.]
 [ 0.  1.]]
