In [1]:
import numpy as np
from sklearn.tree import DecisionTreeClassifier

In [12]:
class WeakLearner:
    def __init__(self, model, i):
        self.__class = i
        self.__model = model
        self.miss_data = None
        self.error_rate = None
        self.alpha=0
    
    def __sign(self, val):
        return 1 if val > 0 else -1
    
    def name(self):
        return self.__name
    
    def model(self):
        return self.__model
    
    def miss_classify(self, data, eval_data):
        self.miss_data = []
        self.predictions = self.__model.predict([d.data for d in data])
        for i in range(len(self.predictions)):
            if self.__sign(self.predictions[i]) != self.__sign(eval_data[i]):
                self.miss_data.append(i)
        
    def calc_error_rate(self, w):
        self.error_rate = np.sum(w[self.miss_data])
    
    def calc_voting_power(self):
        print(f"Updataing on {self.error_rate}")
        self.alpha = 1/2*np.log((1-self.error_rate)/self.error_rate)
        

In [3]:
class DataPoint:
    def __init__(self, vals):
        self.id = vals[0]
        self.data = vals[1]
        self.weight_ = None
    
    def update_weight(self,best):
        if self.id not in best.miss_data:
            self.weight_ = 1/2*(1/(1 - best.error_rate))*self.weight_
        else:
            self.weight_ = 1/2*(1/best.error_rate)*self.weight_

In [4]:
def eval_WL(wl, i):
    return np.sign(np.sum([h.alpha*h.predictions[i] for h in wl]))

In [5]:
def WL_accuracy(wl, data):
    length = len(data)
    summation = 0
    for i in range(length):
        summation += eval_WL(wl, i)
    return summation/length

In [6]:
def ShallowTree():
    return DecisionTreeClassifier(max_depth=2)

In [7]:
def classify(data, classification):
    return [1 if np.where(d == 1)[0][0] == classification else -1 for d in data]

# Local tests

In [8]:
import glob
from PIL import Image
import os
from sklearn.model_selection import train_test_split

IMAGE_DIR = "./data/data/data"

def load():
    file_list = glob.glob(IMAGE_DIR + "/*.jpg")
    X = []
    Y = []

    for fname in file_list:
        with Image.open(fname) as img:
            np_img = np.array(img).flatten()
        label = int(os.path.split(fname)[-1].split('.')[0].split('_')[3])-1

        X.append(np_img)
        tempy = np.zeros(15)
        tempy[label] = 1
        Y.append(tempy)
    X, Y = np.array(X), np.array(Y)
    return X, Y


In [9]:
# create test, train split
X, Y = load()
Xtrain, Xtest, Ytrain, Ytest = train_test_split(X, Y, train_size=0.8, random_state=2021)
Ytrain_classes = []
for i in range(Ytrain.shape[1]):
    Ytrain_0 = classify(Ytrain, i)
    Ytrain_classes.append(Ytrain_0)
    print(f"{Ytrain_0.count(1)} datapoints have classification {i}")

791 datapoints have classification 0
806 datapoints have classification 1
812 datapoints have classification 2
806 datapoints have classification 3
798 datapoints have classification 4
781 datapoints have classification 5
805 datapoints have classification 6
832 datapoints have classification 7
814 datapoints have classification 8
785 datapoints have classification 9
792 datapoints have classification 10
799 datapoints have classification 11
781 datapoints have classification 12
800 datapoints have classification 13
798 datapoints have classification 14


In [15]:
data = [DataPoint((i, Xtrain[i])) for i in range(Xtrain.shape[0])]
for d in data:
    d.weight_ = 1/1000
H = [WeakLearner(ShallowTree(), i) for i in range(10)]
for h in H:
    h.model().fit(Xtrain, Ytrain_classes[0])
for i in range(10):
    print(np.round([d.weight_ for d in data],3))
    for h in H:
        h.miss_classify(data, Ytrain_classes[0])
        h.calc_error_rate(np.array([d.weight_ for d in data]))
    best = H[0]
    optimal = 0
    for num, h in enumerate(H):
        if h.error_rate < best.error_rate:
            best = h
            optimal = num
    for d in data:
        d.update_weight(best)
    best.calc_voting_power()
    H.append(best)
    accuracy = WL_accuracy(H, data)
    print(f"{i}: {num} - {best.alpha}, {accuracy}")

[0.001 0.001 0.001 ... 0.001 0.001 0.001]
Updataing on 0.6950000000000002
0: 9 - -0.4118000344786906, 0.9796666666666667
[0.002 0.002 0.002 ... 0.002 0.002 0.002]
Updataing on 0.49999999999999994
1: 10 - 1.1102230246251564e-16, -0.9796666666666667
[0.002 0.002 0.002 ... 0.002 0.002 0.002]
Updataing on 0.49999999999999994
2: 11 - 1.1102230246251564e-16, -0.9796666666666667
[0.002 0.002 0.002 ... 0.002 0.002 0.002]
Updataing on 0.5
3: 12 - 0.0, 0.0
[0.002 0.002 0.002 ... 0.002 0.002 0.002]
Updataing on 0.5
4: 13 - 0.0, 0.0
[0.002 0.002 0.002 ... 0.002 0.002 0.002]
Updataing on 0.5
5: 14 - 0.0, 0.0
[0.002 0.002 0.002 ... 0.002 0.002 0.002]
Updataing on 0.5
6: 15 - 0.0, 0.0
[0.002 0.002 0.002 ... 0.002 0.002 0.002]
Updataing on 0.5
7: 16 - 0.0, 0.0
[0.002 0.002 0.002 ... 0.002 0.002 0.002]
Updataing on 0.5
8: 17 - 0.0, 0.0
[0.002 0.002 0.002 ... 0.002 0.002 0.002]
Updataing on 0.5
9: 18 - 0.0, 0.0
