In [3]:
import numpy as np

In [4]:
def calc_shannon_ent(dataset):
    """Return the Shannon entropy (in bits) of the class labels in ``dataset``.

    The last element of every row is taken as that row's class label.
    An empty dataset yields ``0.0`` because there are no labels to sum over.
    """
    total = len(dataset)

    # Tally how often each class label occurs.
    counts = {}
    for row in dataset:
        cls = row[-1]
        if cls in counts:
            counts[cls] += 1
        else:
            counts[cls] = 1

    # H = -sum(p * log2(p)) over the observed class labels.
    entropy = 0.0
    for occurrences in counts.values():
        p = occurrences / total
        entropy = entropy - p * np.log2(p)

    return entropy

In [47]:
def create_dataset():
    """Build the small toy dataset used to exercise the tree functions.

    Returns a ``(rows, feature_names)`` pair. Each row holds two 0/1 feature
    values followed by a 'yes'/'no' class label, and ``feature_names`` names
    the two feature columns in order.
    """
    feature_names = ['no surfacing', 'flippers']
    rows = [
        [1, 1, 'yes'],
        [1, 1, 'yes'],
        [1, 0, 'no'],
        [0, 1, 'no'],
        [0, 1, 'no'],
    ]
    return rows, feature_names

In [6]:
def split_dataset(dataset, axis, value):
    """Select the rows whose ``axis`` column equals ``value`` and drop that column.

    Args:
        dataset: Iterable of rows. Each row may be any indexable, sliceable
            sequence (list, tuple, ...); rows are never modified.
        axis: Index of the column to test and remove. It is expected to be
            non-negative: with ``axis=-1`` the tail slice ``row[axis + 1:]``
            becomes ``row[0:]`` and the column would not be removed.
        value: Value the ``axis`` column must equal for a row to be kept.

    Returns:
        A new list containing one new list per matching row, with the
        ``axis`` column left out.
    """
    # list(...) on both halves lets tuple rows work too; the original
    # slice-then-extend form raised AttributeError for anything but lists.
    return [
        list(row[:axis]) + list(row[axis + 1:])
        for row in dataset
        if row[axis] == value
    ]

In [193]:
def choose_best_feature(dataset):
    """Return the index of the feature column with the largest information gain.

    Every row of ``dataset`` holds feature values followed by the class label
    in the last position, so the number of features is ``len(row) - 1``.

    The gain of a feature is the entropy of the whole dataset minus the
    weighted entropy of the subsets obtained by splitting on each distinct
    value of that feature. Ties keep the lowest index, and ``0`` is returned
    when no feature has a gain strictly greater than zero.

    Raises:
        IndexError: If ``dataset`` is empty (its first row is inspected).
    """
    num_features = len(dataset[0]) - 1
    total = len(dataset)

    # The entropy before splitting is the same for every candidate feature,
    # so compute it once rather than once per feature.
    base_ent = calc_shannon_ent(dataset)

    best_gain = 0.0
    best_i = 0
    for i in range(num_features):
        # Weighted entropy of the partition induced by feature i.
        cond_ent = 0.0
        for val in set(row[i] for row in dataset):
            subset = split_dataset(dataset, i, val)
            cond_ent += len(subset) / total * calc_shannon_ent(subset)

        gain = base_ent - cond_ent
        if gain > best_gain:
            best_gain = gain
            best_i = i

    return best_i

In [28]:
def choose_max_label(dataset):
    """Return the class label (last element of each row) that occurs most often.

    When several labels share the highest count, the one seen first in
    ``dataset`` wins. An empty dataset raises ``ValueError`` from ``max``.
    """
    tally = {}
    for row in dataset:
        cls = row[-1]
        if cls in tally:
            tally[cls] += 1
        else:
            tally[cls] = 1

    # max() keeps the first maximal entry, and dicts iterate in insertion order.
    winner, _ = max(tally.items(), key=lambda pair: pair[1])
    return winner

In [194]:
def create_tree(dataset, labels):
    """Recursively build a decision tree represented as nested dicts.

    ``dataset`` rows hold feature values followed by the class label in the
    last position; ``labels`` names the feature columns in order and is not
    modified.

    Returns either a bare class label (a leaf) or a dict of the form
    ``{feature_name: {feature_value: subtree_or_label, ...}}``.
    """
    classes = [row[-1] for row in dataset]

    # Every row carries the same class label: return it as a leaf.
    if len(set(classes)) <= 1:
        return classes[0]

    # No feature columns are left: fall back to the most common label.
    if len(dataset[0]) <= 1:
        return choose_max_label(dataset)

    best_i = choose_best_feature(dataset)
    best_label = labels[best_i]

    # Feature names seen by the subtrees: the chosen column is removed,
    # mirroring what split_dataset does to the rows.
    remaining = labels[:best_i] + labels[best_i + 1:]

    branches = {}
    for val in set(row[best_i] for row in dataset):
        branches[val] = create_tree(split_dataset(dataset, best_i, val), remaining)

    return {best_label: branches}

In [166]:
def classify(input_tree, feature_labels, test_vec):
    """Classify ``test_vec`` by walking the nested-dict decision tree.

    Args:
        input_tree: Tree of the form
            ``{feature_name: {feature_value: subtree_or_label, ...}}``.
        feature_labels: Feature names, in the same order as the values in
            ``test_vec``.
        test_vec: Feature values of the sample to classify.

    Returns:
        The class label stored at the leaf reached by the sample.

    Raises:
        ValueError: If the tree's feature name is not in ``feature_labels``
            (raised by ``list.index``), or if the sample's value for that
            feature has no branch in the tree.
    """
    root_feature = next(iter(input_tree))
    branches = input_tree[root_feature]
    feat_index = feature_labels.index(root_feature)
    value = test_vec[feat_index]

    # A value never seen while building the tree has no branch. Fail with a
    # clear message instead of falling through with no result bound.
    if value not in branches:
        raise ValueError(
            "no branch for value %r of feature %r" % (value, root_feature)
        )

    subtree = branches[value]
    if isinstance(subtree, dict):
        # Internal node: keep descending with the same sample.
        return classify(subtree, feature_labels, test_vec)
    return subtree

In [167]:
# Build the toy dataset and its feature-name list.
my_data, labels = create_dataset()

In [68]:
# Display the current feature-name list.
labels

['no surfacing', 'flippers']

In [198]:
# End-to-end check on the toy data: rebuild the dataset, grow the tree, then
# classify a sample with 'no surfacing' = 1 and 'flippers' = 0.
my_data, labels = create_dataset()
tree = create_tree(my_data, labels)
classify(tree, labels, [1, 0])

'no'

In [199]:
# Load the lenses data: one tab-separated record per line. create_tree treats
# the last field of each record as the class label.
# NOTE(review): presumably each record has four feature fields matching
# `labels` plus the class label -- confirm against lenses.txt.
with open('lenses.txt', 'r') as fr:
    # The context manager closes the file as soon as the rows are read.
    dataset = [line.strip().split('\t') for line in fr]
labels = ['age', 'prescript', 'astigmatic', 'tearRate']

In [200]:
# Grow a decision tree from the lenses rows (rebinds `tree`).
tree = create_tree(dataset, labels)

In [201]:
# Display the nested-dict tree built from the lenses data.
tree

{'tearRate': {'normal': {'astigmatic': {'no': {'age': {'young': 'soft',
      'presbyopic': {'prescript': {'hyper': 'soft', 'myope': 'no lenses'}},
      'pre': 'soft'}},
    'yes': {'prescript': {'hyper': {'age': {'young': 'hard',
        'presbyopic': 'no lenses',
        'pre': 'no lenses'}},
      'myope': 'hard'}}}},
  'reduced': 'no lenses'}}