In [1]:
class Node:
    """One node of the binary decision tree.

    A node owns the subset of rows (``sample``) that reached it and the list
    of feature column indices (``features``) still available for splitting.
    Each row stores its class label at index 0 and feature values at the
    indices listed in ``features``.

    A node is a leaf while ``chosenFeature`` is ``None``. Once
    ``setQuestion`` has been called it is an internal node asking
    ``row[chosenFeature] <= attributeValue``; rows answering yes belong to
    ``leftChild``, the others to ``rightChild``.
    """

    # Display names of the feature columns, keyed by column index.
    attributeNames = {1:"leftWeight", 2:"leftDistance", 3:"rightWeight", 4:"rightDistance"}

    def __init__(self, sample, features):
        """Create a leaf holding ``sample`` with ``features`` left to split on."""
        self.sample = sample
        self.features = features
        self.leftChild = None
        self.rightChild = None
        self.chosenFeature = None
        self.attributeValue = None

    def findMostCommonClass(self):
        """Return the class label that occurs most often in ``sample``.

        Ties are resolved in the order "R", "L", "B", followed by any other
        labels in order of first appearance. An empty sample yields "R".
        """
        # Seeding the known labels keeps the historical tie-break order;
        # .get() lets labels outside that set be counted instead of raising.
        classes = {"R": 0, "L": 0, "B": 0}
        for row in self.sample:
            label = row[0]
            classes[label] = classes.get(label, 0) + 1

        return max(classes, key=classes.get)

    def setQuestion(self, chosenFeature, attributeValue):
        """Turn this node into an internal node testing
        ``row[chosenFeature] <= attributeValue``."""
        self.chosenFeature = chosenFeature
        self.attributeValue = attributeValue

    def __str__(self, level = 0):
        """Render the subtree rooted here, one question per line.

        Each line is indented with ``level`` tab characters. Leaves
        contribute nothing, so a tree consisting of a single leaf renders as
        the empty string.
        """
        if self.chosenFeature is None:
            return ""

        ret = "\t" * level + repr(self) + "\n"
        for child in (self.leftChild, self.rightChild):
            # A question may have been set before the children were attached.
            if child is not None:
                ret += child.__str__(level + 1)
        return ret

    def __repr__(self):
        """Return ``<feature name><=<threshold>`` for an internal node, or a
        short leaf summary for a leaf."""
        if self.chosenFeature is None:
            # Leaves have no question; describe their content instead of
            # failing on the attributeNames lookup.
            return ("Leaf(class=" + str(self.findMostCommonClass())
                    + ", samples=" + str(len(self.sample)) + ")")

        name = Node.attributeNames.get(self.chosenFeature,
                                       "feature" + str(self.chosenFeature))
        return name + "<=" + str(self.attributeValue)

In [2]:
from copy import deepcopy

class DecisionTree:
    """Binary decision tree grown by greedy Gini-impurity minimisation.

    ``data`` is a list of rows whose element 0 is the class label and whose
    remaining elements are feature values. ``features`` lists the column
    indices that may be used for splitting and ``classes`` lists the labels
    counted when computing impurity.
    """

    def __init__(self, data, features, classes):
        """Store the training data and create an unsplit root node."""
        self.data = data
        self.features = features # [1,2,3,4]
        self.classes = classes

        self.root = Node(data, features)

    def getRoot(self):
        """Return the root node of the tree."""
        return self.root

    def classify(self, dataRow):
        """Return the predicted class label for ``dataRow``.

        Walks from the root, going left when
        ``dataRow[feature] <= threshold`` and right otherwise, until a leaf
        is reached; the prediction is that leaf's most common class.
        """
        currentNode = self.root

        while currentNode.chosenFeature is not None:
            if dataRow[currentNode.chosenFeature] <= currentNode.attributeValue:
                currentNode = currentNode.leftChild
            else:
                currentNode = currentNode.rightChild

        return currentNode.findMostCommonClass()

    def generateTree(self, node):
        """Recursively split ``node`` on the best available question.

        Every remaining feature is tried with every threshold that can
        separate the rows, and the split with the lowest weighted Gini
        impurity is kept (the first one found wins ties). The node stays a
        leaf when it has no features left, when its sample is empty or pure,
        or when no candidate is strictly better than the node's own impurity.
        A feature used at a node is not offered to that node's descendants.
        """
        sample = node.sample
        if len(node.features) == 0 or self.uniqueClass(sample):
            return

        # Starting from the parent's impurity means only strictly improving
        # splits are ever recorded. A candidate that does not improve is
        # skipped; it must not abort the search over the remaining ones.
        bestGiniImpurity = self.giniSample(sample)
        bestSplit = None

        for feature in node.features:
            # Thresholds come from the values actually present. The largest
            # value is dropped because "<= max" sends every row to the left.
            thresholds = sorted({row[feature] for row in sample})[:-1]

            for attributeValue in thresholds:
                positiveSample, negativeSample = self._partition(sample, feature, attributeValue)
                giniImpurity = self.giniAttribute(positiveSample, negativeSample)

                if giniImpurity < bestGiniImpurity:
                    bestGiniImpurity = giniImpurity
                    bestSplit = (feature, attributeValue, positiveSample, negativeSample)

        if bestSplit is None:
            return

        bestFeature, bestAttributeValue, positiveSample, negativeSample = bestSplit

        # Features are plain column indices, so a shallow copy is sufficient.
        remainingFeatures = list(node.features)
        remainingFeatures.remove(bestFeature)

        node.setQuestion(bestFeature, bestAttributeValue)
        node.leftChild = Node(positiveSample, list(remainingFeatures))
        node.rightChild = Node(negativeSample, list(remainingFeatures))

        self.generateTree(node.leftChild)
        self.generateTree(node.rightChild)

    def _partition(self, sample, feature, attributeValue):
        """Split ``sample`` into (rows with ``row[feature] <= attributeValue``, the rest)."""
        positiveSample = []
        negativeSample = []

        for row in sample:
            if row[feature] <= attributeValue:
                positiveSample.append(row)
            else:
                negativeSample.append(row)

        return positiveSample, negativeSample

    def giniAttribute(self, positiveSample, negativeSample):
        """Return the size-weighted average Gini impurity of the two subsets.

        Returns 0 when both subsets are empty.
        """
        totalSampleSize = len(positiveSample) + len(negativeSample)
        if totalSampleSize == 0:
            return 0

        return len(positiveSample)/totalSampleSize * self.giniSample(positiveSample) + len(negativeSample)/totalSampleSize * self.giniSample(negativeSample)

    def uniqueClass(self, sample):
        """Return True when ``sample`` is empty or all rows share one class label."""
        if len(sample) == 0:
            return True

        firstLabel = sample[0][0]
        return all(row[0] == firstLabel for row in sample)

    def giniSample(self, sample):
        """Return the Gini impurity of ``sample``: one minus the sum of the
        squared class proportions, taken over ``self.classes``.

        An empty sample has impurity 0. Rows whose label is not listed in
        ``self.classes`` count towards the sample size but towards no class.
        """
        sampleSize = len(sample)

        if sampleSize == 0:
            return 0

        # Count every label in one pass instead of rescanning per class.
        labelCounts = {}
        for row in sample:
            labelCounts[row[0]] = labelCounts.get(row[0], 0) + 1

        g = 1
        for x in self.classes:
            g -= (labelCounts.get(x, 0) / sampleSize) ** 2

        return g


In [7]:
import csv
import random

class Controller:
    """Loads the data file, trains a DecisionTree on it and measures accuracy."""

    # Maximum number of rows kept for evaluation in ``testData``.
    TEST_SAMPLE_SIZE = 100

    def __init__(self):
        """Start with no evaluation rows; ``createDecisionTree`` fills them in."""
        self.testData = []

    def readDataFromFile(self, filename):
        """Read a comma-separated file into a list of rows.

        Each returned row is ``[label, int, int, int, int]`` built from the
        first five columns. Blank lines are skipped. A non-blank line with
        fewer than five columns raises ``IndexError`` and a non-integer
        feature raises ``ValueError``.
        """
        data = []

        # newline="" is what the csv module expects of files handed to it.
        with open(filename, newline="") as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            for row in csv_reader:
                # csv.reader yields [] for an empty line.
                if not any(cell.strip() for cell in row):
                    continue
                data.append([str(row[0])] + [int(row[i]) for i in range(1, 5)])

        return data

    def createDecisionTree(self, filename, trainingDataSize):
        """Train a tree on a random ``trainingDataSize``-row subset of the file.

        The rows are shuffled; the first ``trainingDataSize`` are used for
        training. Up to ``TEST_SAMPLE_SIZE`` of the remaining rows are stored
        in ``self.testData`` for ``testDecisionTree``. If no rows remain
        (the training set covers the whole file) the first rows of the
        shuffled data are used instead, in which case the test rows overlap
        the training rows and the test accuracy is not an independent
        estimate.

        Returns ``(tree, training accuracy in percent)``.
        """
        data = self.readDataFromFile(filename)
        random.shuffle(data)

        trainingData = data[:trainingDataSize]
        heldOutData = data[trainingDataSize:]
        # Evaluate on rows the tree has never seen whenever any exist.
        self.testData = (heldOutData or data)[:self.TEST_SAMPLE_SIZE]

        dt = DecisionTree(trainingData, [1,2,3,4], ['R','B','L'])

        dt.generateTree(dt.getRoot())

        return dt, self._accuracy(dt, trainingData)

    def testDecisionTree(self, dt):
        """Return the accuracy of ``dt`` on ``self.testData`` in percent
        (0.0 when there are no test rows)."""
        return self._accuracy(dt, self.testData)

    def _accuracy(self, dt, rows):
        """Return the percentage of ``rows`` whose label equals ``dt.classify(row)``."""
        if len(rows) == 0:
            return 0.0

        counter = sum(1 for row in rows if row[0] == dt.classify(row))
        return counter/len(rows) * 100
        
    
    

In [9]:
# Path of the comma-separated data file used for both the size report and training.
DATA_FILE = "balance-scale.data"

cont = Controller()

# Report how many rows the file holds.
sampleCount = len(cont.readDataFromFile(DATA_FILE))
print(f"Data sample size:{sampleCount}")

# Train on 625 rows, then score the tree on the controller's test rows.
dt, accuracy = cont.createDecisionTree(DATA_FILE, 625)
testAccuracy = cont.testDecisionTree(dt)

print(f"Training Accuracy: {accuracy}%")
print(f"Testing Accuracy:{testAccuracy}%")

# Dump the learned questions, one per line, indented by depth.
print(f"\n{dt.getRoot()!s}")


Data sample size:625
Training Accuracy: 82.72%
Testing Accuracy:88.0%

leftWeight<=2
	rightWeight<=1
		leftDistance<=2
			rightDistance<=2
			rightDistance<=3
		rightDistance<=1
			leftDistance<=2
			leftDistance<=3
	leftDistance<=2
		rightWeight<=2
			rightDistance<=2
			rightDistance<=1
		rightWeight<=3
			rightDistance<=4
			rightDistance<=3

