In [112]:
from collections import Counter
import math
import pandas as pd
import random
import numpy as np

In [2]:
class Node:
	"""
	A Node class for the ID3 decision tree. This is not to be used directly,
	please use the DecisionTree class instead.
	"""
	def __init__(self, data: pd.DataFrame, Y_attr: str, split_con='None'):
		"""
		Creates a new Node instance.
		
		Args:
		-----
		data: The data to be contained in this node
		Y_attr: The string representing the Y attribute
		split_con: The attribute that this data was split on
		"""
		self.data = data
		self.y_attr = Y_attr
		self.split_condition = split_con
		self.children = []

	def get_entropy(self, data: pd.DataFrame) -> float:
		"""
		Gets the entropy of the data.
		
		Args:
		-----
		data: The data to find the entropy of
		"""
		cnt = Counter(data[self.y_attr])
		probs = [x / len(data.index) for x in cnt.values()]
		return sum([-p * math.log(p, 2) for p in probs])
	
	def get_split_condition(self) -> tuple:
		"""
		Returns the attribute, which on splitting by yields the highest
		information gain.
		"""
		if len(np.unique(self.data[self.y_attr])) == 1:
			return None, 0
        
		best_split = None
		info_gain = 0
		par_entropy = self.get_entropy(self.data)

		for col in self.data.columns:
			if col != self.y_attr:
				groups = list(self.data.groupby(col))
				for _, group in groups:
					entropy = self.get_entropy(group)
					cur_info_gain = par_entropy - len(group.index) / len(self.data.index) * entropy
					if cur_info_gain > info_gain:
						best_split = col
						info_gain = cur_info_gain

		return best_split, info_gain
	
	def split(self, verbose=True) -> None:
		"""
		Splits the data in the current node by the best split condition.
		
		Args:
		-----
		verbose: Prints debug information
		"""
		split, info_gain = self.get_split_condition()

		if split is None:
			if verbose:
				print('Found leaf node.')
            
			n = Node(self.data, self.y_attr, None)

		if verbose:
			print('Splitting on', split, 'with information gain', info_gain)
		
		# Split the data by the condition
		groups = list(self.data.groupby(split))

		if verbose:
			print('Children:')

		for _, group in groups:
			# Remove the split condition column and create a node from the
			# resulting dataset
			group.drop(split, axis=1, inplace=True)

			if verbose:
				print('---------\n', group)

			n = Node(group, self.y_attr, split)
			self.children.append(n)

In [121]:
guess_count = 0

In [120]:
class DecisionTree(Node):
	"""
	A DecisionTree class that implements the ID3 algorithm using the Node
	class.	
	"""
	def __init__(self, data: pd.DataFrame, y: str):
		"""
		Creates a DecisionTree object.
		
		Args:
		-----
		data: The data for the current node
		y: The output attribute
		"""
		super().__init__(data, y)

	def fit(self) -> None:
		"""
		Creates the full decision tree from the current data.
		"""
		stack = [self]

		while len(stack) > 0:
			node = stack.pop()
			
			# If entropy is 0, then stop splitting.
			if node.get_entropy(node.data) > 0:
				node.split(verbose=False)
				for child in node.children:
					stack.append(child)

	def print(self) -> None:
		"""
		Prints the decision tree nodes' data.
		"""
		level = 0
		stack = [(level, self)]

		while len(stack) > 0:
			level, node = stack.pop()
			print('\nLevel', level, 'Split condition:', node.split_condition, 
				'\n-----------')
			print(node.data)

			for child in node.children:
				stack.append((level + 1, child))

	def predict(self, sample: pd.DataFrame) -> str:
		"""
		Returns the class label for the given sample.
		
		Args:
		-----
		sample: A DataFrame containing a single sample to predict on
		"""
		node = self

		# Guess randomly if looping infinitely
		it_counter = 0

		while len(node.children) > 0:
			# Get the current splitting condition.
			split_con = node.children[0].split_condition
			
			for child in node.children:
				data = child.data
				
				# Get the first sample in this split
				first_sample = list(data.index)[0]
				
				# Check if this child has the right value of the splitting
				# condition. If not, try another child.
				#print(node.data.index)
				#print(first_sample)
				#print(split_con)
				it_counter += 1
				if sample[split_con][0] == node.data.loc[first_sample,:][split_con]:
					node = child
					it_counter = 0
					break
                
				if it_counter == 5:
					global guess_count
					guess_count += 1
					it_counter = 0
					print('Random guess', guess_count)
					y_uniq = np.unique(node.data[self.y_attr])
					return random.choice(y_uniq)
		
		return list(node.data[self.y_attr])[0]

In [56]:
df = pd.read_csv('connect-4.data')

In [17]:
df.head()

Unnamed: 0,b,b.1,b.2,b.3,b.4,b.5,b.6,b.7,b.8,b.9,...,b.25,b.26,b.27,b.28,b.29,b.30,b.31,b.32,b.33,win
0,b,b,b,b,b,b,b,b,b,b,...,b,b,b,b,b,b,b,b,b,win
1,b,b,b,b,b,b,o,b,b,b,...,b,b,b,b,b,b,b,b,b,win
2,b,b,b,b,b,b,b,b,b,b,...,b,b,b,b,b,b,b,b,b,win
3,o,b,b,b,b,b,b,b,b,b,...,b,b,b,b,b,b,b,b,b,win
4,b,b,b,b,b,b,b,b,b,b,...,b,b,b,o,b,b,b,b,b,win


In [57]:
df = df[df['win'] != 'draw']

In [19]:
df.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 61107 entries, 0 to 67553
Data columns (total 43 columns):
b       61107 non-null object
b.1     61107 non-null object
b.2     61107 non-null object
b.3     61107 non-null object
b.4     61107 non-null object
b.5     61107 non-null object
b.6     61107 non-null object
b.7     61107 non-null object
b.8     61107 non-null object
b.9     61107 non-null object
b.10    61107 non-null object
b.11    61107 non-null object
x       61107 non-null object
o       61107 non-null object
b.12    61107 non-null object
b.13    61107 non-null object
b.14    61107 non-null object
b.15    61107 non-null object
x.1     61107 non-null object
o.1     61107 non-null object
x.2     61107 non-null object
o.2     61107 non-null object
x.3     61107 non-null object
o.3     61107 non-null object
b.16    61107 non-null object
b.17    61107 non-null object
b.18    61107 non-null object
b.19    61107 non-null object
b.20    61107 non-null object
b.21    61107 non-nul

In [6]:
from sklearn.model_selection import train_test_split

In [58]:
train, test = train_test_split(df, train_size=0.7)

In [59]:
test = test.reset_index()

In [67]:
for row in test.iterrows():
    print(pd.DataFrame(row[1]).T)
    break

   index  b b.1 b.2 b.3 b.4 b.5 b.6 b.7 b.8  ...  b.25 b.26 b.27 b.28 b.29  \
0  27917  b   b   b   b   b   b   b   b   b  ...     b    b    b    x    b   

  b.30 b.31 b.32 b.33   win  
0    b    b    b    b  loss  

[1 rows x 44 columns]


In [22]:
train.info()

<class 'pandas.core.frame.DataFrame'>
Int64Index: 42774 entries, 57721 to 46633
Data columns (total 43 columns):
b       42774 non-null object
b.1     42774 non-null object
b.2     42774 non-null object
b.3     42774 non-null object
b.4     42774 non-null object
b.5     42774 non-null object
b.6     42774 non-null object
b.7     42774 non-null object
b.8     42774 non-null object
b.9     42774 non-null object
b.10    42774 non-null object
b.11    42774 non-null object
x       42774 non-null object
o       42774 non-null object
b.12    42774 non-null object
b.13    42774 non-null object
b.14    42774 non-null object
b.15    42774 non-null object
x.1     42774 non-null object
o.1     42774 non-null object
x.2     42774 non-null object
o.2     42774 non-null object
x.3     42774 non-null object
o.3     42774 non-null object
b.16    42774 non-null object
b.17    42774 non-null object
b.18    42774 non-null object
b.19    42774 non-null object
b.20    42774 non-null object
b.21    42774 non

In [25]:
root = DecisionTree(train, 'win')

In [26]:
root.fit()

KeyboardInterrupt: 

In [93]:
import pickle
from tqdm import tqdm_notebook

In [122]:
with open('tree_model.pkl', 'rb') as f:
    root = pickle.load(f)

In [123]:
predictions = []

for i, row in tqdm_notebook(test.iterrows(), total=test.shape[0]):
    df = pd.DataFrame(row).T
    df.index = [0]
    predictions.append(root.predict(df))

HBox(children=(IntProgress(value=0, max=18333), HTML(value='')))

Random guess 1
Random guess 2
Random guess 3
Random guess 4
Random guess 5
Random guess 6
Random guess 7
Random guess 8
Random guess 9
Random guess 10
Random guess 11
Random guess 12
Random guess 13
Random guess 14
Random guess 15
Random guess 16
Random guess 17
Random guess 18
Random guess 19
Random guess 20
Random guess 21
Random guess 22
Random guess 23
Random guess 24
Random guess 25
Random guess 26
Random guess 27
Random guess 28
Random guess 29
Random guess 30
Random guess 31
Random guess 32
Random guess 33
Random guess 34
Random guess 35
Random guess 36
Random guess 37
Random guess 38
Random guess 39
Random guess 40
Random guess 41
Random guess 42
Random guess 43
Random guess 44
Random guess 45
Random guess 46
Random guess 47
Random guess 48
Random guess 49
Random guess 50
Random guess 51
Random guess 52
Random guess 53
Random guess 54
Random guess 55
Random guess 56
Random guess 57
Random guess 58
Random guess 59
Random guess 60
Random guess 61
Random guess 62
Random guess 63
R

In [124]:
predictions[:5]

['loss', 'win', 'win', 'loss', 'win']

In [131]:
from sklearn.metrics import accuracy_score, classification_report

In [130]:
accuracy_score(predictions, test.iloc[:,-1])

0.9412534773359515

In [133]:
print(classification_report(predictions, test.iloc[:,-1]))

              precision    recall  f1-score   support

        loss       0.88      0.90      0.89      4965
         win       0.96      0.96      0.96     13368

   micro avg       0.94      0.94      0.94     18333
   macro avg       0.92      0.93      0.93     18333
weighted avg       0.94      0.94      0.94     18333

