In [1]:
import numpy as np
import numpy.random as rand

In [2]:
from torchvision.datasets import MNIST

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
# Load the MNIST train and test splits from ./data, downloading them on first use.
train = MNIST(root="data", train=True, download=True)
test = MNIST(root="data", train=False, download=True)

In [19]:
# Collapse every dimension after the first, giving one flat pixel row per training image.
foo = train.data.flatten(1, -1)
foo.shape

torch.Size([60000, 784])

In [15]:
train.targets.shape  # one label per training image

torch.Size([60000])

In [130]:
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier

# Random-forest baseline: fit on only the first `n_samples` MNIST training
# images (each flattened to a single row of pixels) and report accuracy on
# the whole test split.
crit = "gini"
n_samples = 100
model = RandomForestClassifier(
    n_estimators=10,
    criterion=crit,
    max_depth=5,
    max_features=25,
    min_samples_split=2,
    min_samples_leaf=1,
    min_impurity_decrease=0.01,
    ccp_alpha=0.01,
    random_state=420,
)
model.fit(train.data[:n_samples].flatten(start_dim=1), train.targets[:n_samples])

model.score(test.data.flatten(start_dim=1), test.targets)

0.559

In [30]:
# NUM_DIM: length of every candidate vector produced by Solutions.new_solution.
# LIMIT: consecutive non-improving trials after which the scout phase
# replaces an employed source with a fresh random one.
NUM_DIM, LIMIT = 5, 4

class Solutions:
    """Population of candidate vectors ("food sources") for the ABC search.

    Row ``i`` of ``solutions`` is one candidate of length ``NUM_DIM``. The
    first ``num // 2`` rows belong to employed bees and the remaining rows to
    onlooker (unemployed) bees.

    Attributes:
        solutions: array of shape (num, NUM_DIM).
        employed: boolean mask over rows, True for employed bees.
        failures: int32 count of consecutive non-improving trials per row.
        onlooker: boolean mask, the complement of ``employed``.
        best_fitness: best fitness recorded by ``update_best`` (-1 initially).
        best_sol: private copy of the best vector recorded (None initially).
    """

    def __init__(self, num):
        self.solutions = np.stack([Solutions.new_solution() for _ in range(num)])
        num_employed = num // 2
        # The onlooker half takes the remainder so every per-row mask has
        # exactly ``num`` entries, also when ``num`` is odd.
        self.employed = np.array([True] * num_employed + [False] * (num - num_employed))
        self.failures = np.zeros_like(self.employed, np.int32)
        self.onlooker = np.logical_not(self.employed)  # Unemployed
        self.best_fitness = -1
        self.best_sol = None

    @staticmethod
    def new_solution():
        """Return a fresh random candidate with entries drawn from [0, 1)."""
        return rand.random(size=NUM_DIM)

    def random_sol(self, exclude=-1):
        """Return a uniformly random row of ``solutions``.

        Args:
            exclude: index of a row that must not be returned; -1 (or None)
                means every row may be returned.

        Returns:
            One row of ``solutions``. This is a view, not a copy.

        Raises:
            ValueError: if a row is excluded but it is the only row.
        """
        n = self.solutions.shape[0]
        if exclude is None or exclude == -1:
            # np.random.choice only accepts 1-D input, so draw a row index
            # instead of passing the 2-D array.
            return self.solutions[rand.randint(0, n)]
        if n < 2:
            raise ValueError("need at least two solutions to exclude one")
        # Draw uniformly from the n - 1 allowed indices and skip over
        # ``exclude``. This is equivalent to rejection sampling but always
        # terminates.
        idx = rand.randint(0, n - 1)
        if idx >= exclude:
            idx += 1
        return self.solutions[idx]

    def get_employed(self):
        """Return the rows tended by employed bees."""
        return self.solutions[self.employed]

    def get_unemployed(self):
        """Return the rows not tended by employed bees."""
        return self.solutions[np.logical_not(self.employed)]

    def get_onlooker(self):
        """Return the rows tended by onlooker bees."""
        return self.solutions[self.onlooker]

    def get_scout(self, limit=None):
        """Return employed rows whose failure count has reached ``limit``.

        These are the sources that the scout phase abandons and re-randomizes.
        ``limit`` defaults to the module-level ``LIMIT``.
        """
        if limit is None:
            limit = LIMIT
        return self.solutions[self.employed & (self.failures >= limit)]

    def fitness(self, x):
        """Score a candidate; higher is better.

        This is a placeholder objective: the L1 norm of ``x``.
        """
        # Todo return some positive number
        fit = np.abs(x).sum()
        return fit

    def most_fit(self):
        """Return ``(fitness, row)`` for the fittest current row.

        The returned row is a view into ``solutions``, not a copy.
        """
        fit = np.array([self.fitness(x) for x in self.solutions])
        idx = fit.argmax()
        return fit[idx], self.solutions[idx]

    def update_best(self):
        """Record the fittest current row if it beats the best seen so far."""
        best_fit, best_sol = self.most_fit()
        if best_fit > self.best_fitness:
            # Copy: ``most_fit`` returns a view into ``solutions``, whose rows
            # are later overwritten in place by the search. Keeping the view
            # would let ``best_sol`` drift away from ``best_fitness``.
            self.best_sol = np.copy(best_sol)
            self.best_fitness = best_fit

def basic_employed(sol: "Solutions", initial_idx: int, a: float = 0.1):
    """Propose a neighbour of ``sol.solutions[initial_idx]`` (ABC employed-bee move).

    One randomly chosen coordinate ``j`` is moved relative to the same
    coordinate of a different, randomly chosen solution ``x_k``:

        v_j = x_j + phi * (x_j - x_k[j]),   phi ~ U(-a, a)

    The caller is expected to keep the result only if it is fitter (greedy
    selection).

    Args:
        sol: population that provides the base vector and the partner ``x_k``.
        initial_idx: row index of the base vector in ``sol.solutions``.
        a: half-width of the uniform range for ``phi``. The default 0.1 is
            the previously hard-coded, still untuned value.

    Returns:
        A new array; ``sol.solutions`` is not modified.
    """
    initial = sol.solutions[initial_idx]
    # RNG draw order (coordinate, phi, partner) is unchanged so seeded runs
    # with the default ``a`` reproduce earlier results.
    idx = rand.randint(0, initial.size)
    phi = rand.uniform(low=-a, high=a)
    out = np.copy(initial)
    sol_k = sol.random_sol(exclude=initial_idx)
    out[idx] += phi * (out[idx] - sol_k[idx])
    # Todo: values are not clamped, so they can leave the [0, 1) range that
    # Solutions.new_solution samples from.
    return out  # Greedy select this

def basic_onlooker(sol: "Solutions"):
    """Run one ABC onlooker step: pick an employed source, then perturb it.

    An employed bee is chosen with probability proportional to its fitness
    (roulette-wheel selection), and ``basic_employed`` is applied to that
    bee's row.

    Args:
        sol: the population; ``sol.employed`` marks the selectable rows.

    Returns:
        The candidate vector produced by ``basic_employed``.

    Raises:
        ValueError: if no bee is employed.
    """
    # Row indices of the employed bees in ``sol.solutions``. The sampled
    # position must be mapped back through this array: a position inside
    # ``get_employed()`` is only a valid row index when the employed bees
    # happen to occupy a prefix of the population.
    employed_rows = np.flatnonzero(sol.employed)
    if employed_rows.size == 0:
        raise ValueError("basic_onlooker needs at least one employed bee")
    employed = sol.get_employed()
    fitnesses = np.array([sol.fitness(x) for x in employed])  # Todo figure out how to cache fitness
    total_fitness = np.sum(fitnesses)
    if total_fitness > 0:
        probs = fitnesses / total_fitness
    else:
        # All-zero fitness would give NaN probabilities; choose uniformly.
        probs = None
    bee_row = rand.choice(employed_rows, p=probs)
    return basic_employed(sol, int(bee_row))

def vanilla_abc(num_bees, epoches, limit=None, seed=None, verbose=True):
    """Run the basic artificial bee colony (ABC) search and return the best result.

    Each epoch runs three phases:
      * employed: every employed bee proposes a neighbour of its own row and
        keeps it only if it is fitter, otherwise its failure count grows;
      * onlooker: every onlooker bee takes a candidate from ``basic_onlooker``
        and keeps it if it beats the onlooker's current row;
      * scout: employed rows that failed ``limit`` times in a row are replaced
        by fresh random vectors.
    The best solution seen so far is then recorded.

    Args:
        num_bees: population size passed to ``Solutions``.
        epoches: number of epochs to run.
        limit: failure count that triggers the scout phase. Defaults to the
            module-level ``LIMIT``.
        seed: if given, ``numpy.random`` is seeded first so the run is
            reproducible.
        verbose: print the best fitness and best solution after every epoch.

    Returns:
        ``(best_fitness, best_sol)`` after the last epoch. This is
        ``(-1, None)`` if ``epoches`` is 0.
    """
    if limit is None:
        limit = LIMIT
    if seed is not None:
        rand.seed(seed)
    sol = Solutions(num_bees)
    for _ in range(epoches):
        # Employed phase: greedy local search around each employed row.
        for idx in sol.employed.nonzero()[0]:
            candidate = basic_employed(sol, idx)
            if sol.fitness(candidate) > sol.fitness(sol.solutions[idx]):
                sol.solutions[idx] = candidate
                sol.failures[idx] = 0
            else:
                sol.failures[idx] += 1
        # Onlooker phase: fitness-proportional exploitation of employed rows.
        # NOTE(review): textbook ABC compares the candidate against the chosen
        # employed source (and updates its failure count) rather than the
        # onlooker's own row. Confirm that this variant is intended.
        for idx in sol.onlooker.nonzero()[0]:
            candidate = basic_onlooker(sol)
            if sol.fitness(candidate) > sol.fitness(sol.solutions[idx]):
                sol.solutions[idx] = candidate
        # Scout phase: abandon employed rows that stopped improving.
        for idx in sol.employed.nonzero()[0]:
            if sol.failures[idx] >= limit:
                sol.failures[idx] = 0
                sol.solutions[idx, :] = Solutions.new_solution()
        # Record the best solution found so far.
        sol.update_best()
        if verbose:
            print(sol.best_fitness)
            print(sol.best_sol)
    return sol.best_fitness, sol.best_sol


vanilla_abc(num_bees=100, epoches=10);  # progress is printed; suppress any return value display

4.006490358609521
[0.9208981  0.95892401 0.40608905 0.99211276 0.72846643]
4.133913498913904
[0.9208981  0.95892401 0.53351219 0.99211276 0.72846643]
4.133913498913904
[0.9208981  0.95892401 0.53351219 0.99211276 0.72846643]
4.133913498913904
[0.9208981  0.95892401 0.53351219 0.99211276 0.72846643]
4.133913498913904
[0.9208981  0.95892401 0.53351219 0.99211276 0.72846643]
4.173846084992998
[1.173577   0.71934997 1.01741206 0.45910316 0.8044039 ]
4.2646803318715785
[0.88192051 1.45127139 0.97009402 0.66648024 0.29491417]
4.3406290961836556
[0.88226493 1.45127139 0.97009402 0.66648024 0.37051852]
4.52739583448014
[1.14463601 1.45127139 0.97009402 0.66648024 0.29491417]
4.52739583448014
[1.14463601 1.45127139 0.97009402 0.66648024 0.29491417]
